|
|
|
@ -169,6 +169,7 @@ class Tagger:
|
|
|
|
|
cache_dir=None,
|
|
|
|
|
hf_token=HF_TOKEN,
|
|
|
|
|
loglevel=logging.INFO,
|
|
|
|
|
num_threads=None,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Tagger object with the model repository and tokens.
|
|
|
|
|
|
|
|
|
@ -176,6 +177,8 @@ class Tagger:
|
|
|
|
|
model_repo (str): Repository name on HuggingFace.
|
|
|
|
|
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
|
|
|
|
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
self.logger = logging.getLogger("wdtagger")
|
|
|
|
|
self.logger.setLevel(loglevel)
|
|
|
|
@ -183,7 +186,7 @@ class Tagger:
|
|
|
|
|
self.model_target_size = None
|
|
|
|
|
self.cache_dir = cache_dir
|
|
|
|
|
self.hf_token = hf_token
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token)
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
|
|
|
|
|
|
|
|
|
|
def load_model(
|
|
|
|
|
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
|
|
|
@ -194,6 +197,7 @@ class Tagger:
|
|
|
|
|
model_repo (str): Repository name on HuggingFace.
|
|
|
|
|
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
with console.status("Loading model..."):
|
|
|
|
|
csv_path = huggingface_hub.hf_hub_download(
|
|
|
|
@ -252,8 +256,9 @@ class Tagger:
|
|
|
|
|
(target_size, target_size),
|
|
|
|
|
Image.BICUBIC,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
array = np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
array = array[:, :, [2, 1, 0]]
|
|
|
|
|
return array
|
|
|
|
|
|
|
|
|
|
def tag(
|
|
|
|
|
self,
|
|
|
|
|