diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 3494ba1..7651dd5 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -183,7 +183,7 @@ class Tagger: self.model_target_size = None self.cache_dir = cache_dir self.hf_token = hf_token - self.load_model(model_repo, cache_dir, hf_token) + self.load_model(model_repo, cache_dir, hf_token, num_threads=8) def load_model( self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None @@ -252,9 +252,9 @@ class Tagger: (target_size, target_size), Image.BICUBIC, ) - - return np.asarray(padded_image, dtype=np.float32) - + array = np.asarray(padded_image, dtype=np.float32) + array = array[:, :, [2, 1, 0]] + return array def tag( self, image: Image.Image | list[Image.Image],