4 Commits

Author SHA1 Message Date
be7085f2f7 version: v0.4.0 2024-06-11 18:28:51 +09:00
fd67f54fcc 📚 docs: add comment docs 2024-06-11 18:28:30 +09:00
4e5221d7a8 feat(config): add the num_threads option 2024-06-11 16:17:09 +09:00
5e4629b8ea 🩹 fix(color): rgb -> bgr 2024-06-11 16:15:40 +09:00
2 changed files with 9 additions and 4 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.3.0"
version = "0.4.0"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"

View File

@ -169,6 +169,7 @@ class Tagger:
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -176,6 +177,8 @@ class Tagger:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
@ -183,7 +186,7 @@ class Tagger:
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token)
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
@ -194,6 +197,7 @@ class Tagger:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
@ -252,8 +256,9 @@ class Tagger:
(target_size, target_size),
Image.BICUBIC,
)
return np.asarray(padded_image, dtype=np.float32)
array = np.asarray(padded_image, dtype=np.float32)
array = array[:, :, [2, 1, 0]]
return array
def tag(
self,