6 Commits

Author SHA1 Message Date
be7085f2f7 version: v0.4.0 2024-06-11 18:28:51 +09:00
fd67f54fcc 📚 docs: add comment docs 2024-06-11 18:28:30 +09:00
4e5221d7a8 feat(config): add the num_threads option 2024-06-11 16:17:09 +09:00
5e4629b8ea 🩹 fix(color): rgb -> bgr 2024-06-11 16:15:40 +09:00
ae2838f58e version: v0.3.0 2024-06-11 03:18:40 +09:00
f9ec9de157 feat(logger): use logger 2024-06-11 03:18:23 +09:00
3 changed files with 44 additions and 12 deletions

9
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,9 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnType": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
}

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.2.0"
version = "0.4.0"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"

View File

@ -1,3 +1,4 @@
import logging
import os
import time
from collections import OrderedDict
@ -9,6 +10,7 @@ import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
@ -67,7 +69,13 @@ def load_labels(dataframe) -> list[str]:
class Result:
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
def __init__(
self,
pred,
sep_tags,
general_threshold=0.35,
character_threshold=0.9,
):
"""Initialize the Result object to store tagging results.
Args:
@ -160,6 +168,8 @@ class Tagger:
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -167,19 +177,27 @@ class Tagger:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
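num_threads (int, optional): Thread count applied to both the intra-op and inter-op settings of the ONNX Runtime session. Defaults to None, which leaves the ONNX Runtime defaults unchanged.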
"""
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token)
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
def load_model(self, model_repo, cache_dir=None, hf_token=None):
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
):
"""Load the model and tags from the specified repository.
Args:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
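num_threads (int, optional): Thread count applied to both the intra-op and inter-op settings of the ONNX Runtime session. Defaults to None, which leaves the ONNX Runtime defaults unchanged.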
"""
with console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
@ -188,6 +206,7 @@ class Tagger:
cache_dir=cache_dir,
use_auth_token=hf_token,
)
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
@ -197,8 +216,11 @@ class Tagger:
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
model = rt.InferenceSession(model_path)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(model_path, options)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
@ -234,8 +256,9 @@ class Tagger:
(target_size, target_size),
Image.BICUBIC,
)
return np.asarray(padded_image, dtype=np.float32)
array = np.asarray(padded_image, dtype=np.float32)
array = array[:, :, [2, 1, 0]]
return array
def tag(
self,
@ -266,9 +289,9 @@ class Tagger:
]
duration = time.time() - started_at
image_length = len(images)
console.log(f"Tagging {image_length} image{
's' if image_length > 1 else ''
} took {duration:.2f} seconds.")
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
return results[0] if len(results) == 1 else results
@ -278,4 +301,4 @@ if __name__ == "__main__":
tagger = Tagger()
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
result = tagger.tag(image)
console.log(result)
tagger.logger.info(result)