4 Commits

3 changed files with 136 additions and 110 deletions

View File

@ -1,6 +1,6 @@
[project]
name = "wdtagger"
version = "0.11.0"
version = "0.11.1"
description = "A simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations."
authors = [{ name = "Jianqi Pan", email = "jannchie@gmail.com" }]
readme = "README.md"
@ -10,7 +10,6 @@ dependencies = [
"numpy>=2.1.3",
"pandas>=2.2.3",
"pillow>=11.0.0",
"rich>=13.9.4",
]
[dependency-groups]

View File

@ -14,7 +14,6 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from PIL.ImageFile import ImageFile
from rich.console import Console
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
@ -137,15 +136,7 @@ class Result:
character_tag: dict[str, float],
general_tag: dict[str, float],
) -> None:
"""Initialize the Result object to store tagging results.
Args:
preds (np.array): Predictions array from the model.
sep_tags (tuple): Tuple containing separated tags based on categories.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
"""
"""Initialize the Result object with the tags and their ratings."""
self.general_tag_data = general_tag
self.character_tag_data = character_tag
self.rating_data = rating_data
@ -153,12 +144,27 @@ class Result:
@property
def general_tags(self) -> tuple[str, ...]:
"""Return general tags as a tuple."""
return tuple(self.general_tag_data.keys())
return tuple(
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
for d in sorted(
self.general_tag_data,
key=lambda k: self.general_tag_data[k],
reverse=True,
)
)
@property
def character_tags(self) -> tuple[str, ...]:
"""Return character tags as a tuple."""
return tuple(self.character_tag_data.keys())
return tuple(
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
for d in sorted(
self.character_tag_data,
key=lambda k: self.character_tag_data[k],
reverse=True,
)
)
@property
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
@ -174,7 +180,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(string)
@property
def character_tags_string(self) -> str:
@ -185,7 +191,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(string)
@property
def all_tags(self) -> list[str]:
@ -194,7 +200,7 @@ class Result:
@property
def all_tags_string(self) -> str:
return ", ".join(self.all_tags).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(self.all_tags)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@ -213,30 +219,9 @@ class Tagger:
self,
model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
hf_token: str = HF_TOKEN,
console: Console | None = None,
*,
slient: bool = False,
) -> None:
"""Initialize the Tagger object with the model repository and tokens.
Args:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
self.slient = slient
if not slient:
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
self.logger = logging.getLogger("wdtagger")
"""Initialize the Tagger object with the model repository and tokens."""
self.logger = logging.getLogger("wdtagger")
self.model_target_size = None
self.hf_token = hf_token
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -254,14 +239,13 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
if not self.slient:
with self.console.status(f"Loading model '{model_repo}'..."):
self._do_load_model(model_repo)
else:
self._do_load_model(model_repo)
start_time = time.time()
self.logger.info("Loading model from %s", model_repo)
self._do_load_model(model_repo)
self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time)
def _do_load_model(self, model_repo: str) -> None:
model: nn.Module = timm.create_model("hf-hub:" + model_repo).eval()
model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
state_dict = timm.models.load_state_dict_from_hf(model_repo)
model.load_state_dict(state_dict)
@ -339,13 +323,12 @@ class Tagger:
duration = time.time() - started_at
image_length = len(images)
if not self.slient:
self.logger.info(
"Tagged %d image%s in %.2fs",
image_length,
"s" if image_length > 1 else "",
duration,
)
self.logger.info(
"Tagged %d image%s in %.2fs",
image_length,
"s" if image_length > 1 else "",
duration,
)
if input_is_list:
return results
return results[0] if len(results) == 1 else results

158
uv.lock generated
View File

@ -2,13 +2,31 @@ version = 1
requires-python = ">=3.10"
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
]
conflicts = [[
@ -178,18 +196,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271 },
]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
]
[[package]]
name = "markupsafe"
version = "3.0.2"
@ -248,15 +254,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 },
]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
]
[[package]]
name = "mpmath"
version = "1.3.0"
@ -599,15 +596,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 },
]
[[package]]
name = "pygments"
version = "2.18.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 },
]
[[package]]
name = "pytest"
version = "8.3.3"
@ -718,20 +706,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
]
[[package]]
name = "rich"
version = "13.9.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 },
]
[[package]]
name = "ruff"
version = "0.8.0"
@ -881,13 +855,31 @@ version = "2.5.1+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
]
dependencies = [
@ -915,13 +907,31 @@ version = "2.5.1+cu124"
source = { registry = "https://download.pytorch.org/whl/cu124" }
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
]
dependencies = [
@ -962,13 +972,31 @@ version = "0.20.1+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
]
dependencies = [
@ -991,13 +1019,31 @@ version = "0.20.1+cu124"
source = { registry = "https://download.pytorch.org/whl/cu124" }
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
"python_full_version >= '3.12'",
]
dependencies = [
@ -1068,14 +1114,13 @@ wheels = [
[[package]]
name = "wdtagger"
version = "0.10.2"
version = "0.11.0"
source = { editable = "." }
dependencies = [
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "pandas" },
{ name = "pillow" },
{ name = "rich" },
]
[package.optional-dependencies]
@ -1103,7 +1148,6 @@ requires-dist = [
{ name = "numpy", specifier = ">=2.1.3" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "pillow", specifier = ">=11.0.0" },
{ name = "rich", specifier = ">=13.9.4" },
{ name = "timm", marker = "extra == 'cpu'", specifier = ">=1.0.11" },
{ name = "timm", marker = "extra == 'gpu'", specifier = ">=1.0.11" },
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.5.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "wdtagger", extra = "cpu" } },