feat(config): add the num_threads option

This commit is contained in:
Jannchie 2024-06-11 16:17:09 +09:00
parent 5e4629b8ea
commit 4e5221d7a8

View File

@ -169,6 +169,7 @@ class Tagger:
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -183,7 +184,7 @@ class Tagger:
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token, num_threads=8)
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
@ -255,6 +256,7 @@ class Tagger:
array = np.asarray(padded_image, dtype=np.float32)
array = array[:, :, [2, 1, 0]]
return array
def tag(
self,
image: Image.Image | list[Image.Image],