29 Commits

Author SHA1 Message Date
447ad99aea version: v0.10.2 2024-11-22 19:25:44 +09:00
52931f2a59 chore(dependencies): update multiple package versions in poetry.lock 2024-11-22 19:25:23 +09:00
ec1f0b6e26 chore(dependencies): update various package versions in poetry.lock && update pillow version range in pyproject.toml 2024-11-22 19:23:59 +09:00
fe5ad89a29 version: v0.10.1 2024-08-24 19:12:20 +09:00
c7fbdeb7b1 chore(type): better result type 2024-08-24 19:12:07 +09:00
5bcbf991e8 🔧 chore(mypy): remove stub 2024-07-29 23:54:24 +09:00
f6955384b7 version: v0.10.0 2024-07-29 21:08:01 +09:00
117f153b33 feat(slient): allow turn off all outputs 2024-07-29 21:07:48 +09:00
09d331d8f0 version: v0.9.0 2024-07-10 16:53:47 +09:00
6d1cbc0b2b feat(threading): allow use external rich.console 2024-07-10 16:53:32 +09:00
c910823173 version: v0.8.0 2024-06-25 00:38:26 +09:00
16686ce3c4 feat(result): return all tags 2024-06-25 00:38:14 +09:00
9eafed31dc version: v0.7.0 2024-06-23 21:28:06 +09:00
a4003953e6 feat(input): enable input path or numpy array 2024-06-23 21:27:52 +09:00
69f35eb373 version: v0.6.0 2024-06-23 03:21:06 +09:00
c82f881c85 feat(provider): enable change onnx provider 2024-06-23 03:20:56 +09:00
b9a763468a 📚 docs: add codetime badge 2024-06-19 18:34:17 +09:00
01935d9e82 version: v0.5.0 2024-06-19 18:30:37 +09:00
5b39dc7735 docs(benchmark): add batch benchmark results 2024-06-19 18:29:51 +09:00
dbec094a3d test(benchmark): add batch size benchmark 2024-06-19 18:28:47 +09:00
1e6b04c0ec feat(onnx): use gpu 2024-06-19 18:28:15 +09:00
be7085f2f7 version: v0.4.0 2024-06-11 18:28:51 +09:00
fd67f54fcc 📚 docs: add comment docs 2024-06-11 18:28:30 +09:00
4e5221d7a8 feat(config): add the num_threads option 2024-06-11 16:17:09 +09:00
5e4629b8ea 🩹 fix(color): rgb -> bgr 2024-06-11 16:15:40 +09:00
ae2838f58e version: v0.3.0 2024-06-11 03:18:40 +09:00
f9ec9de157 feat(logger): use logger 2024-06-11 03:18:23 +09:00
ea88303c8a version: v0.2.0 2024-06-08 00:13:16 +09:00
1f1ab1e3df feat(tagger): batch tagger 2024-06-08 00:12:41 +09:00
8 changed files with 924 additions and 468 deletions

1
.gitignore vendored
View File

@ -2,3 +2,4 @@ dist
models
__pycache__
.pytest_cache
.ruff_cache

9
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,9 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnType": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
}

View File

@ -1,5 +1,7 @@
# wdtagger
[![CodeTime Badge](https://img.shields.io/endpoint?style=social&color=222&url=https%3A%2F%2Fapi.codetime.dev%2Fshield%3Fid%3D2%26project%3Dwdtagger%26in=0)](https://codetime.dev)
`wdtagger` is a simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations.
## Installation
@ -23,3 +25,17 @@ image = Image.open("image.jpg")
result = tagger.tag(image)
print(result)
```
You can input a image list to the tagger to use batch processing, it is faster than single image processing (test on RTX 3090):
```log
---------------------------------------------------------------------------------- benchmark 'tagger': 5 tests -----------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_tagger_benchmark[16] 540.8711 (1.0) 598.5156 (1.04) 558.2777 (1.0) 22.2954 (4.10) 549.9650 (1.0) 21.7318 (2.51) 2;2 1.7912 (1.0) 10 1
test_tagger_benchmark[8] 558.9445 (1.03) 576.7220 (1.0) 567.9235 (1.02) 5.4381 (1.0) 568.7336 (1.03) 8.6569 (1.0) 2;0 1.7608 (0.98) 10 1
test_tagger_benchmark[4] 590.6479 (1.09) 626.7126 (1.09) 597.9712 (1.07) 11.0124 (2.03) 594.5067 (1.08) 10.7656 (1.24) 1;1 1.6723 (0.93) 10 1
test_tagger_benchmark[2] 622.8689 (1.15) 643.5122 (1.12) 630.1096 (1.13) 7.2365 (1.33) 627.1716 (1.14) 9.5823 (1.11) 3;0 1.5870 (0.89) 10 1
test_tagger_benchmark[1] 700.6986 (1.30) 816.3089 (1.42) 721.7431 (1.29) 33.9031 (6.23) 712.6850 (1.30) 12.8756 (1.49) 1;1 1.3855 (0.77) 10 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

1049
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,21 +1,34 @@
[tool.poetry]
name = "wdtagger"
version = "0.1.0"
version = "0.10.2"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
onnxruntime = "^1.18.0"
pillow = "^10.3.0"
pillow = ">=10.3.0,<12.0.0"
pandas = "^2.2.2"
huggingface-hub = "^0.23.3"
rich = "^13.7.1"
onnxruntime-gpu = "^1.18.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
pytest-benchmark = "^4.0.0"
ruff = "^0.5.5"
isort = "^5.13.2"
black = "^24.4.2"
mypy = "^1.11.0"
[tool.isort]
profile = "black"
[tool.black]
line-length = 120
[tool.mypy]
ignore_missing_imports = true
[build-system]
requires = ["poetry-core"]

33
tests/benchmark_tagger.py Normal file
View File

@ -0,0 +1,33 @@
import os
from typing import Any
import pytest
from PIL import Image
from wdtagger import Tagger
tagger = Tagger()
image_dir = "./tests/images/"
image_paths = [os.path.join(image_dir, image) for image in os.listdir(image_dir)] * 16
images = [Image.open(image_path) for image_path in image_paths]
def tag_in_batch(images: Any, batch: Any = 1) -> None:
for i in range(0, len(images), batch):
tagger.tag(images[i : i + batch])
@pytest.mark.benchmark(
group="tagger",
min_rounds=10,
warmup=False,
disable_gc=True,
)
@pytest.mark.parametrize("batch", [1, 2, 4, 8, 16])
def test_tagger_benchmark(benchmark: Any, batch: Any) -> None:
# warmup
tag_in_batch(images[:1])
benchmark.pedantic(tag_in_batch, args=(images, batch), iterations=1, rounds=10)
# cmd: pytest tests/benchmark_tagger.py -v

View File

@ -1,3 +1,4 @@
import numpy as np
import pytest
from PIL import Image
@ -20,3 +21,39 @@ def test_tagger(tagger, image_file):
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path(tagger, image_file):
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger, image_file):
image = Image.open(image_file)
image_np = np.array(image)
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger, image_file):
image = Image.open(image_file)
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
def test_tagger_np_single(tagger, image_file):
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert len(results) == 1
result = results[0]
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"

View File

@ -1,25 +1,36 @@
import logging
import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Literal, Sequence, Union, overload
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Environment variables and file paths
HF_TOKEN = os.environ.get(
"HF_TOKEN", ""
) # Token for authentication with HuggingFace API
MODEL_FILENAME = "model.onnx" # ONNX model filename
LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename
Input = Union[np.ndarray, Image.Image, str, Path]
def load_labels(dataframe) -> list[str]:
def to_pil(img: Input) -> Image.Image:
if isinstance(img, (str, Path)):
return Image.open(img)
elif isinstance(img, np.ndarray):
return Image.fromarray(img)
elif isinstance(img, Image.Image):
return img
else:
raise ValueError("Invalid input type.")
def load_labels(dataframe) -> tuple[Any, list[Any], list[Any], list[Any]]:
"""Load labels from a dataframe and process tag names.
Args:
@ -53,9 +64,7 @@ def load_labels(dataframe) -> list[str]:
"|_|",
"||_||",
]
name_series = name_series.map(
lambda x: x.replace("_", " ") if x not in kaomojis else x
)
name_series = name_series.map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
tag_names = name_series.tolist()
rating_indexes = list(np.where(dataframe["category"] == 9)[0])
general_indexes = list(np.where(dataframe["category"] == 0)[0])
@ -66,7 +75,11 @@ def load_labels(dataframe) -> list[str]:
class Result:
def __init__(
self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9
self,
pred,
sep_tags,
general_threshold=0.35,
character_threshold=0.9,
):
"""Initialize the Result object to store tagging results.
@ -80,14 +93,12 @@ class Result:
rating_indexes = sep_tags[1]
general_indexes = sep_tags[2]
character_indexes = sep_tags[3]
labels = list(zip(tag_names, preds[0].astype(float)))
labels = list(zip(tag_names, pred.astype(float)))
# Ratings
ratings_names = [labels[i] for i in rating_indexes]
rating_data = dict(ratings_names)
rating_data = OrderedDict(
sorted(rating_data.items(), key=lambda x: x[1], reverse=True)
)
rating_data = OrderedDict(sorted(rating_data.items(), key=lambda x: x[1], reverse=True))
# General tags
general_names = [labels[i] for i in general_indexes]
@ -97,26 +108,24 @@ class Result:
# Character tags
character_names = [labels[i] for i in character_indexes]
character_tag = [x for x in character_names if x[1] > character_threshold]
character_tag = OrderedDict(
sorted(character_tag, key=lambda x: x[1], reverse=True)
)
character_tag = OrderedDict(sorted(character_tag, key=lambda x: x[1], reverse=True))
self.general_tag_data = general_tag
self.character_tag_data = character_tag
self.rating_data = rating_data
@property
def general_tags(self):
def general_tags(self) -> tuple[str]:
"""Return general tags as a tuple."""
return tuple(self.general_tag_data.keys())
@property
def character_tags(self):
def character_tags(self) -> tuple[str]:
"""Return character tags as a tuple."""
return tuple(self.character_tag_data.keys())
@property
def rating(self):
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
"""Return the highest rated tag."""
return max(self.rating_data, key=self.rating_data.get)
@ -129,9 +138,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
@property
def character_tags_string(self) -> str:
@ -142,8 +149,16 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
@property
def all_tags(self) -> list[str]:
"""Return all tags as a list."""
return [self.rating] + list(self.general_tags) + list(self.character_tags)
@property
def all_tags_string(self) -> str:
return ", ".join(self.all_tags)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@ -163,6 +178,11 @@ class Tagger:
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
providers=None,
console=None,
slient=False,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -170,43 +190,89 @@ class Tagger:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
self.slient = slient
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
if not slient:
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token)
self.load_model(
model_repo,
cache_dir,
hf_token,
num_threads=num_threads,
providers=providers,
)
def load_model(self, model_repo, cache_dir=None, hf_token=None):
def load_model(
self,
model_repo,
cache_dir=None,
hf_token=None,
num_threads: int | None = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
):
"""Load the model and tags from the specified repository.
Args:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
if not self.slient:
with self.console.status("Loading model..."):
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
else:
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
def do_load_model(self, model_repo, cache_dir, hf_token, num_threads, providers):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
model = rt.InferenceSession(model_path)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
cache_dir=cache_dir,
use_auth_token=hf_token,
)
def prepare_image(self, image):
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(
model_path,
options,
providers=providers,
)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
def pil_to_cv2_numpy(self, image):
"""Prepare the image for model input.
Args:
@ -237,45 +303,53 @@ class Tagger:
(target_size, target_size),
Image.BICUBIC,
)
array = np.asarray(padded_image, dtype=np.float32)
array = array[:, :, [2, 1, 0]]
return array
# Convert to numpy array
image_array = np.asarray(padded_image, dtype=np.float32)
@overload
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
# Convert PIL-native RGB to BGR
image_array = image_array[:, :, ::-1]
return np.expand_dims(image_array, axis=0)
@overload
def tag(
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
) -> List[Result]: ...
def tag(
self,
image,
image: Union[Input, List[Input]],
general_threshold=0.35,
character_threshold=0.9,
):
) -> Result | list[Result]:
"""Tag the image and return the results.
Args:
image (PIL.Image): Input image.
image (Union[Input, List[Input]]): Input image or list of images to tag.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
Returns:
Result: Object containing the tagging results.
Result | list[Result]: Tagging results.
"""
with console.status("Tagging..."):
image = self.prepare_image(image)
image_array = np.asarray(image, dtype=np.float32)
image_array = image_array[:, :, ::-1] # Convert PIL-native RGB to BGR
image_array = np.expand_dims(image_array, axis=0)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image_array[0]})[0]
result = Result(
preds, self.sep_tags, general_threshold, character_threshold
started_at = time.time()
input_is_list = isinstance(image, list)
images = image if isinstance(image, list) else [image]
images = [to_pil(img) for img in images]
images = [self.pil_to_cv2_numpy(img) for img in images]
image_array = np.asarray(images, dtype=np.float32)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image_array})[0]
results = [Result(pred, self.sep_tags, general_threshold, character_threshold) for pred in preds]
duration = time.time() - started_at
image_length = len(images)
if not self.slient:
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
return result
if input_is_list:
return results
return results[0] if len(results) == 1 else results
__all__ = ["Tagger"]
@ -283,5 +357,5 @@ __all__ = ["Tagger"]
if __name__ == "__main__":
tagger = Tagger()
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
print(result)
result = tagger.tag(image)
tagger.logger.info(result)