Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
1f04dab915 | |||
31a72457af | |||
b1776ef564 | |||
9fb556c41c |
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "wdtagger"
|
||||
version = "0.11.0"
|
||||
version = "0.11.1"
|
||||
description = "A simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations."
|
||||
authors = [{ name = "Jianqi Pan", email = "jannchie@gmail.com" }]
|
||||
readme = "README.md"
|
||||
@ -10,7 +10,6 @@ dependencies = [
|
||||
"numpy>=2.1.3",
|
||||
"pandas>=2.2.3",
|
||||
"pillow>=11.0.0",
|
||||
"rich>=13.9.4",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
@ -14,7 +14,6 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
from PIL import Image
|
||||
from PIL.ImageFile import ImageFile
|
||||
from rich.console import Console
|
||||
from timm.data import create_transform, resolve_data_config
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
@ -137,15 +136,7 @@ class Result:
|
||||
character_tag: dict[str, float],
|
||||
general_tag: dict[str, float],
|
||||
) -> None:
|
||||
"""Initialize the Result object to store tagging results.
|
||||
|
||||
Args:
|
||||
preds (np.array): Predictions array from the model.
|
||||
sep_tags (tuple): Tuple containing separated tags based on categories.
|
||||
general_threshold (float): Threshold for general tags.
|
||||
character_threshold (float): Threshold for character tags.
|
||||
"""
|
||||
|
||||
"""Initialize the Result object with the tags and their ratings."""
|
||||
self.general_tag_data = general_tag
|
||||
self.character_tag_data = character_tag
|
||||
self.rating_data = rating_data
|
||||
@ -153,12 +144,27 @@ class Result:
|
||||
@property
|
||||
def general_tags(self) -> tuple[str, ...]:
|
||||
"""Return general tags as a tuple."""
|
||||
return tuple(self.general_tag_data.keys())
|
||||
|
||||
return tuple(
|
||||
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
|
||||
for d in sorted(
|
||||
self.general_tag_data,
|
||||
key=lambda k: self.general_tag_data[k],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def character_tags(self) -> tuple[str, ...]:
|
||||
"""Return character tags as a tuple."""
|
||||
return tuple(self.character_tag_data.keys())
|
||||
return tuple(
|
||||
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
|
||||
for d in sorted(
|
||||
self.character_tag_data,
|
||||
key=lambda k: self.character_tag_data[k],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
|
||||
@ -174,7 +180,7 @@ class Result:
|
||||
reverse=True,
|
||||
)
|
||||
string = [x[0] for x in string]
|
||||
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
|
||||
return ", ".join(string)
|
||||
|
||||
@property
|
||||
def character_tags_string(self) -> str:
|
||||
@ -185,7 +191,7 @@ class Result:
|
||||
reverse=True,
|
||||
)
|
||||
string = [x[0] for x in string]
|
||||
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
|
||||
return ", ".join(string)
|
||||
|
||||
@property
|
||||
def all_tags(self) -> list[str]:
|
||||
@ -194,7 +200,7 @@ class Result:
|
||||
|
||||
@property
|
||||
def all_tags_string(self) -> str:
|
||||
return ", ".join(self.all_tags).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
|
||||
return ", ".join(self.all_tags)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a formatted string representation of the tags and their ratings."""
|
||||
@ -213,30 +219,9 @@ class Tagger:
|
||||
self,
|
||||
model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
|
||||
hf_token: str = HF_TOKEN,
|
||||
console: Console | None = None,
|
||||
*,
|
||||
slient: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the Tagger object with the model repository and tokens.
|
||||
|
||||
Args:
|
||||
model_repo (str): Repository name on HuggingFace.
|
||||
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
||||
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
||||
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
||||
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
||||
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
|
||||
console (rich.console.Console, optional): Rich console object. Defaults to None.
|
||||
"""
|
||||
self.slient = slient
|
||||
if not slient:
|
||||
if not console:
|
||||
from rich import get_console
|
||||
|
||||
self.console = get_console()
|
||||
else:
|
||||
self.console = console
|
||||
self.logger = logging.getLogger("wdtagger")
|
||||
"""Initialize the Tagger object with the model repository and tokens."""
|
||||
self.logger = logging.getLogger("wdtagger")
|
||||
self.model_target_size = None
|
||||
self.hf_token = hf_token
|
||||
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -254,14 +239,13 @@ class Tagger:
|
||||
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
|
||||
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
||||
"""
|
||||
if not self.slient:
|
||||
with self.console.status(f"Loading model '{model_repo}'..."):
|
||||
self._do_load_model(model_repo)
|
||||
else:
|
||||
self._do_load_model(model_repo)
|
||||
start_time = time.time()
|
||||
self.logger.info("Loading model from %s", model_repo)
|
||||
self._do_load_model(model_repo)
|
||||
self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time)
|
||||
|
||||
def _do_load_model(self, model_repo: str) -> None:
|
||||
model: nn.Module = timm.create_model("hf-hub:" + model_repo).eval()
|
||||
model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
|
||||
state_dict = timm.models.load_state_dict_from_hf(model_repo)
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
@ -339,13 +323,12 @@ class Tagger:
|
||||
|
||||
duration = time.time() - started_at
|
||||
image_length = len(images)
|
||||
if not self.slient:
|
||||
self.logger.info(
|
||||
"Tagged %d image%s in %.2fs",
|
||||
image_length,
|
||||
"s" if image_length > 1 else "",
|
||||
duration,
|
||||
)
|
||||
self.logger.info(
|
||||
"Tagged %d image%s in %.2fs",
|
||||
image_length,
|
||||
"s" if image_length > 1 else "",
|
||||
duration,
|
||||
)
|
||||
if input_is_list:
|
||||
return results
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
158
uv.lock
generated
158
uv.lock
generated
@ -2,13 +2,31 @@ version = 1
|
||||
requires-python = ">=3.10"
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
]
|
||||
conflicts = [[
|
||||
@ -178,18 +196,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mdurl" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markupsafe"
|
||||
version = "3.0.2"
|
||||
@ -248,15 +254,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
@ -599,15 +596,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.18.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.3"
|
||||
@ -718,20 +706,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.9.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markdown-it-py" },
|
||||
{ name = "pygments" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.8.0"
|
||||
@ -881,13 +855,31 @@ version = "2.5.1+cpu"
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
]
|
||||
dependencies = [
|
||||
@ -915,13 +907,31 @@ version = "2.5.1+cu124"
|
||||
source = { registry = "https://download.pytorch.org/whl/cu124" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
]
|
||||
dependencies = [
|
||||
@ -962,13 +972,31 @@ version = "0.20.1+cpu"
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
]
|
||||
dependencies = [
|
||||
@ -991,13 +1019,31 @@ version = "0.20.1+cu124"
|
||||
source = { registry = "https://download.pytorch.org/whl/cu124" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
"python_full_version >= '3.12'",
|
||||
]
|
||||
dependencies = [
|
||||
@ -1068,14 +1114,13 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "wdtagger"
|
||||
version = "0.10.2"
|
||||
version = "0.11.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pandas" },
|
||||
{ name = "pillow" },
|
||||
{ name = "rich" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@ -1103,7 +1148,6 @@ requires-dist = [
|
||||
{ name = "numpy", specifier = ">=2.1.3" },
|
||||
{ name = "pandas", specifier = ">=2.2.3" },
|
||||
{ name = "pillow", specifier = ">=11.0.0" },
|
||||
{ name = "rich", specifier = ">=13.9.4" },
|
||||
{ name = "timm", marker = "extra == 'cpu'", specifier = ">=1.0.11" },
|
||||
{ name = "timm", marker = "extra == 'gpu'", specifier = ">=1.0.11" },
|
||||
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.5.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "wdtagger", extra = "cpu" } },
|
||||
|
Reference in New Issue
Block a user