12 Commits

5 changed files with 727 additions and 433 deletions

1
.gitignore vendored
View File

@ -2,3 +2,4 @@ dist
models
__pycache__
.pytest_cache
.ruff_cache

997
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,12 @@
[tool.poetry]
name = "wdtagger"
version = "0.7.0"
version = "0.10.2"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
pillow = "^10.3.0"
pillow = ">=10.3.0,<12.0.0"
pandas = "^2.2.2"
huggingface-hub = "^0.23.3"
rich = "^13.7.1"
@ -17,6 +16,19 @@ onnxruntime-gpu = "^1.18.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
pytest-benchmark = "^4.0.0"
ruff = "^0.5.5"
isort = "^5.13.2"
black = "^24.4.2"
mypy = "^1.11.0"
[tool.isort]
profile = "black"
[tool.black]
line-length = 120
[tool.mypy]
ignore_missing_imports = true
[build-system]
requires = ["poetry-core"]

View File

@ -1,4 +1,5 @@
import os
from typing import Any
import pytest
from PIL import Image
@ -11,7 +12,7 @@ image_paths = [os.path.join(image_dir, image) for image in os.listdir(image_dir)
images = [Image.open(image_path) for image_path in image_paths]
def tag_in_batch(images, batch=1):
def tag_in_batch(images: Any, batch: Any = 1) -> None:
for i in range(0, len(images), batch):
tagger.tag(images[i : i + batch])
@ -23,7 +24,7 @@ def tag_in_batch(images, batch=1):
disable_gc=True,
)
@pytest.mark.parametrize("batch", [1, 2, 4, 8, 16])
def test_tagger_benchmark(benchmark, batch):
def test_tagger_benchmark(benchmark: Any, batch: Any) -> None:
# warmup
tag_in_batch(images[:1])
benchmark.pedantic(tag_in_batch, args=(images, batch), iterations=1, rounds=10)

View File

@ -3,20 +3,15 @@ import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Sequence, Union
from typing import Any, List, Literal, Sequence, Union, overload
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
@ -35,7 +30,7 @@ def to_pil(img: Input) -> Image.Image:
raise ValueError("Invalid input type.")
def load_labels(dataframe) -> list[str]:
def load_labels(dataframe) -> tuple[Any, list[Any], list[Any], list[Any]]:
"""Load labels from a dataframe and process tag names.
Args:
@ -69,9 +64,7 @@ def load_labels(dataframe) -> list[str]:
"|_|",
"||_||",
]
name_series = name_series.map(
lambda x: x.replace("_", " ") if x not in kaomojis else x
)
name_series = name_series.map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
tag_names = name_series.tolist()
rating_indexes = list(np.where(dataframe["category"] == 9)[0])
general_indexes = list(np.where(dataframe["category"] == 0)[0])
@ -105,9 +98,7 @@ class Result:
# Ratings
ratings_names = [labels[i] for i in rating_indexes]
rating_data = dict(ratings_names)
rating_data = OrderedDict(
sorted(rating_data.items(), key=lambda x: x[1], reverse=True)
)
rating_data = OrderedDict(sorted(rating_data.items(), key=lambda x: x[1], reverse=True))
# General tags
general_names = [labels[i] for i in general_indexes]
@ -117,26 +108,24 @@ class Result:
# Character tags
character_names = [labels[i] for i in character_indexes]
character_tag = [x for x in character_names if x[1] > character_threshold]
character_tag = OrderedDict(
sorted(character_tag, key=lambda x: x[1], reverse=True)
)
character_tag = OrderedDict(sorted(character_tag, key=lambda x: x[1], reverse=True))
self.general_tag_data = general_tag
self.character_tag_data = character_tag
self.rating_data = rating_data
@property
def general_tags(self):
def general_tags(self) -> tuple[str]:
"""Return general tags as a tuple."""
return tuple(self.general_tag_data.keys())
@property
def character_tags(self):
def character_tags(self) -> tuple[str]:
"""Return character tags as a tuple."""
return tuple(self.character_tag_data.keys())
@property
def rating(self):
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
"""Return the highest rated tag."""
return max(self.rating_data, key=self.rating_data.get)
@ -162,6 +151,15 @@ class Result:
string = [x[0] for x in string]
return ", ".join(string)
@property
def all_tags(self) -> list[str]:
"""Return all tags as a list."""
return [self.rating] + list(self.general_tags) + list(self.character_tags)
@property
def all_tags_string(self) -> str:
return ", ".join(self.all_tags)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@ -183,6 +181,8 @@ class Tagger:
loglevel=logging.INFO,
num_threads=None,
providers=None,
console=None,
slient=False,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -192,12 +192,23 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
self.slient = slient
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
if not slient:
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
@ -214,8 +225,8 @@ class Tagger:
model_repo,
cache_dir=None,
hf_token=None,
num_threads: int = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] = None,
num_threads: int | None = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
):
"""Load the model and tags from the specified repository.
@ -225,35 +236,41 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
if not self.slient:
with self.console.status("Loading model..."):
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
else:
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
def do_load_model(self, model_repo, cache_dir, hf_token, num_threads, providers):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(
model_path,
options,
providers=providers,
)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(
model_path,
options,
providers=providers,
)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
def pil_to_cv2_numpy(self, image):
"""Prepare the image for model input.
@ -290,6 +307,14 @@ class Tagger:
array = array[:, :, [2, 1, 0]]
return array
@overload
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
@overload
def tag(
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
) -> List[Result]: ...
def tag(
self,
image: Union[Input, List[Input]],
@ -315,15 +340,13 @@ class Tagger:
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image_array})[0]
results = [
Result(pred, self.sep_tags, general_threshold, character_threshold)
for pred in preds
]
results = [Result(pred, self.sep_tags, general_threshold, character_threshold) for pred in preds]
duration = time.time() - started_at
image_length = len(images)
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
if not self.slient:
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
if input_is_list:
return results
return results[0] if len(results) == 1 else results