diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 7651dd5..bc2952a 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -169,6 +169,7 @@ class Tagger: cache_dir=None, hf_token=HF_TOKEN, loglevel=logging.INFO, + num_threads=None, ): """Initialize the Tagger object with the model repository and tokens. @@ -183,7 +184,7 @@ class Tagger: self.model_target_size = None self.cache_dir = cache_dir self.hf_token = hf_token - self.load_model(model_repo, cache_dir, hf_token, num_threads=8) + self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads) def load_model( self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None @@ -255,6 +256,7 @@ class Tagger: array = np.asarray(padded_image, dtype=np.float32) array = array[:, :, [2, 1, 0]] return array + def tag( self, image: Image.Image | list[Image.Image],