✨ feat(config): add the num_threads option
This commit is contained in:
parent
5e4629b8ea
commit
4e5221d7a8
@ -169,6 +169,7 @@ class Tagger:
|
||||
cache_dir=None,
|
||||
hf_token=HF_TOKEN,
|
||||
loglevel=logging.INFO,
|
||||
num_threads=None,
|
||||
):
|
||||
"""Initialize the Tagger object with the model repository and tokens.
|
||||
|
||||
@ -183,7 +184,7 @@ class Tagger:
|
||||
self.model_target_size = None
|
||||
self.cache_dir = cache_dir
|
||||
self.hf_token = hf_token
|
||||
self.load_model(model_repo, cache_dir, hf_token, num_threads=8)
|
||||
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
|
||||
|
||||
def load_model(
|
||||
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
||||
@ -255,6 +256,7 @@ class Tagger:
|
||||
array = np.asarray(padded_image, dtype=np.float32)
|
||||
array = array[:, :, [2, 1, 0]]
|
||||
return array
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image: Image.Image | list[Image.Image],
|
||||
|
Loading…
x
Reference in New Issue
Block a user