Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
447ad99aea | |||
52931f2a59 | |||
ec1f0b6e26 | |||
fe5ad89a29 | |||
c7fbdeb7b1 | |||
5bcbf991e8 | |||
f6955384b7 | |||
117f153b33 | |||
09d331d8f0 | |||
6d1cbc0b2b | |||
c910823173 | |||
16686ce3c4 | |||
9eafed31dc | |||
a4003953e6 | |||
69f35eb373 | |||
c82f881c85 | |||
b9a763468a |
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@ dist
|
||||
models
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
.ruff_cache
|
@ -1,5 +1,7 @@
|
||||
# wdtagger
|
||||
|
||||
[](https://codetime.dev)
|
||||
|
||||
`wdtagger` is a simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations.
|
||||
|
||||
## Installation
|
||||
|
997
poetry.lock
generated
997
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,12 @@
|
||||
[tool.poetry]
|
||||
name = "wdtagger"
|
||||
version = "0.5.0"
|
||||
version = "0.10.2"
|
||||
description = ""
|
||||
authors = ["Jianqi Pan <jannchie@gmail.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
pillow = "^10.3.0"
|
||||
pillow = ">=10.3.0,<12.0.0"
|
||||
pandas = "^2.2.2"
|
||||
huggingface-hub = "^0.23.3"
|
||||
rich = "^13.7.1"
|
||||
@ -17,6 +16,19 @@ onnxruntime-gpu = "^1.18.0"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^8.2.2"
|
||||
pytest-benchmark = "^4.0.0"
|
||||
ruff = "^0.5.5"
|
||||
isort = "^5.13.2"
|
||||
black = "^24.4.2"
|
||||
mypy = "^1.11.0"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
@ -11,7 +12,7 @@ image_paths = [os.path.join(image_dir, image) for image in os.listdir(image_dir)
|
||||
images = [Image.open(image_path) for image_path in image_paths]
|
||||
|
||||
|
||||
def tag_in_batch(images, batch=1):
|
||||
def tag_in_batch(images: Any, batch: Any = 1) -> None:
|
||||
for i in range(0, len(images), batch):
|
||||
tagger.tag(images[i : i + batch])
|
||||
|
||||
@ -23,7 +24,7 @@ def tag_in_batch(images, batch=1):
|
||||
disable_gc=True,
|
||||
)
|
||||
@pytest.mark.parametrize("batch", [1, 2, 4, 8, 16])
|
||||
def test_tagger_benchmark(benchmark, batch):
|
||||
def test_tagger_benchmark(benchmark: Any, batch: Any) -> None:
|
||||
# warmup
|
||||
tag_in_batch(images[:1])
|
||||
benchmark.pedantic(tag_in_batch, args=(images, batch), iterations=1, rounds=10)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
@ -20,3 +21,39 @@ def test_tagger(tagger, image_file):
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path(tagger, image_file):
|
||||
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_np(tagger, image_file):
|
||||
image = Image.open(image_file)
|
||||
image_np = np.array(image)
|
||||
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_pil(tagger, image_file):
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
|
||||
def test_tagger_np_single(tagger, image_file):
|
||||
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
@ -2,32 +2,35 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Sequence, Union, overload
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import onnxruntime as rt
|
||||
import pandas as pd
|
||||
import rich
|
||||
import rich.live
|
||||
from PIL import Image
|
||||
from rich.logging import RichHandler
|
||||
|
||||
# Access console for rich text and logging
|
||||
console = rich.get_console()
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
||||
MODEL_FILENAME = "model.onnx"
|
||||
LABEL_FILENAME = "selected_tags.csv"
|
||||
|
||||
# Environment variables and file paths
|
||||
HF_TOKEN = os.environ.get(
|
||||
"HF_TOKEN", ""
|
||||
) # Token for authentication with HuggingFace API
|
||||
MODEL_FILENAME = "model.onnx" # ONNX model filename
|
||||
LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename
|
||||
|
||||
available_providers = rt.get_available_providers()
|
||||
supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
|
||||
providers = list(set(available_providers) & set(supported_providers))
|
||||
Input = Union[np.ndarray, Image.Image, str, Path]
|
||||
|
||||
|
||||
def load_labels(dataframe) -> list[str]:
|
||||
def to_pil(img: Input) -> Image.Image:
|
||||
if isinstance(img, (str, Path)):
|
||||
return Image.open(img)
|
||||
elif isinstance(img, np.ndarray):
|
||||
return Image.fromarray(img)
|
||||
elif isinstance(img, Image.Image):
|
||||
return img
|
||||
else:
|
||||
raise ValueError("Invalid input type.")
|
||||
|
||||
|
||||
def load_labels(dataframe) -> tuple[Any, list[Any], list[Any], list[Any]]:
|
||||
"""Load labels from a dataframe and process tag names.
|
||||
|
||||
Args:
|
||||
@ -61,9 +64,7 @@ def load_labels(dataframe) -> list[str]:
|
||||
"|_|",
|
||||
"||_||",
|
||||
]
|
||||
name_series = name_series.map(
|
||||
lambda x: x.replace("_", " ") if x not in kaomojis else x
|
||||
)
|
||||
name_series = name_series.map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
|
||||
tag_names = name_series.tolist()
|
||||
rating_indexes = list(np.where(dataframe["category"] == 9)[0])
|
||||
general_indexes = list(np.where(dataframe["category"] == 0)[0])
|
||||
@ -97,9 +98,7 @@ class Result:
|
||||
# Ratings
|
||||
ratings_names = [labels[i] for i in rating_indexes]
|
||||
rating_data = dict(ratings_names)
|
||||
rating_data = OrderedDict(
|
||||
sorted(rating_data.items(), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
rating_data = OrderedDict(sorted(rating_data.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
# General tags
|
||||
general_names = [labels[i] for i in general_indexes]
|
||||
@ -109,26 +108,24 @@ class Result:
|
||||
# Character tags
|
||||
character_names = [labels[i] for i in character_indexes]
|
||||
character_tag = [x for x in character_names if x[1] > character_threshold]
|
||||
character_tag = OrderedDict(
|
||||
sorted(character_tag, key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
character_tag = OrderedDict(sorted(character_tag, key=lambda x: x[1], reverse=True))
|
||||
|
||||
self.general_tag_data = general_tag
|
||||
self.character_tag_data = character_tag
|
||||
self.rating_data = rating_data
|
||||
|
||||
@property
|
||||
def general_tags(self):
|
||||
def general_tags(self) -> tuple[str]:
|
||||
"""Return general tags as a tuple."""
|
||||
return tuple(self.general_tag_data.keys())
|
||||
|
||||
@property
|
||||
def character_tags(self):
|
||||
def character_tags(self) -> tuple[str]:
|
||||
"""Return character tags as a tuple."""
|
||||
return tuple(self.character_tag_data.keys())
|
||||
|
||||
@property
|
||||
def rating(self):
|
||||
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
|
||||
"""Return the highest rated tag."""
|
||||
return max(self.rating_data, key=self.rating_data.get)
|
||||
|
||||
@ -154,6 +151,15 @@ class Result:
|
||||
string = [x[0] for x in string]
|
||||
return ", ".join(string)
|
||||
|
||||
@property
|
||||
def all_tags(self) -> list[str]:
|
||||
"""Return all tags as a list."""
|
||||
return [self.rating] + list(self.general_tags) + list(self.character_tags)
|
||||
|
||||
@property
|
||||
def all_tags_string(self) -> str:
|
||||
return ", ".join(self.all_tags)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a formatted string representation of the tags and their ratings."""
|
||||
|
||||
@ -174,6 +180,9 @@ class Tagger:
|
||||
hf_token=HF_TOKEN,
|
||||
loglevel=logging.INFO,
|
||||
num_threads=None,
|
||||
providers=None,
|
||||
console=None,
|
||||
slient=False,
|
||||
):
|
||||
"""Initialize the Tagger object with the model repository and tokens.
|
||||
|
||||
@ -183,17 +192,41 @@ class Tagger:
|
||||
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
||||
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
||||
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
||||
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
|
||||
console (rich.console.Console, optional): Rich console object. Defaults to None.
|
||||
"""
|
||||
self.logger = logging.getLogger("wdtagger")
|
||||
self.logger.setLevel(loglevel)
|
||||
self.logger.addHandler(RichHandler())
|
||||
self.slient = slient
|
||||
|
||||
if providers is None:
|
||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
if not slient:
|
||||
if not console:
|
||||
from rich import get_console
|
||||
|
||||
self.console = get_console()
|
||||
else:
|
||||
self.console = console
|
||||
self.logger = logging.getLogger("wdtagger")
|
||||
self.logger.setLevel(loglevel)
|
||||
self.logger.addHandler(RichHandler())
|
||||
self.model_target_size = None
|
||||
self.cache_dir = cache_dir
|
||||
self.hf_token = hf_token
|
||||
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
|
||||
self.load_model(
|
||||
model_repo,
|
||||
cache_dir,
|
||||
hf_token,
|
||||
num_threads=num_threads,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
def load_model(
|
||||
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
||||
self,
|
||||
model_repo,
|
||||
cache_dir=None,
|
||||
hf_token=None,
|
||||
num_threads: int | None = None,
|
||||
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
|
||||
):
|
||||
"""Load the model and tags from the specified repository.
|
||||
|
||||
@ -203,37 +236,43 @@ class Tagger:
|
||||
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
|
||||
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
||||
"""
|
||||
with console.status("Loading model..."):
|
||||
csv_path = huggingface_hub.hf_hub_download(
|
||||
model_repo,
|
||||
LABEL_FILENAME,
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
if not self.slient:
|
||||
with self.console.status("Loading model..."):
|
||||
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
|
||||
else:
|
||||
self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
|
||||
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
model_repo,
|
||||
MODEL_FILENAME,
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
def do_load_model(self, model_repo, cache_dir, hf_token, num_threads, providers):
|
||||
csv_path = huggingface_hub.hf_hub_download(
|
||||
model_repo,
|
||||
LABEL_FILENAME,
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
|
||||
tags_df = pd.read_csv(csv_path)
|
||||
self.sep_tags = load_labels(tags_df)
|
||||
options = rt.SessionOptions()
|
||||
if num_threads:
|
||||
options.intra_op_num_threads = num_threads
|
||||
options.inter_op_num_threads = num_threads
|
||||
model = rt.InferenceSession(
|
||||
model_path,
|
||||
options,
|
||||
providers=providers,
|
||||
)
|
||||
_, height, _, _ = model.get_inputs()[0].shape
|
||||
self.model_target_size = height
|
||||
self.model = model
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
model_repo,
|
||||
MODEL_FILENAME,
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
|
||||
def prepare_image(self, image):
|
||||
tags_df = pd.read_csv(csv_path)
|
||||
self.sep_tags = load_labels(tags_df)
|
||||
options = rt.SessionOptions()
|
||||
if num_threads:
|
||||
options.intra_op_num_threads = num_threads
|
||||
options.inter_op_num_threads = num_threads
|
||||
model = rt.InferenceSession(
|
||||
model_path,
|
||||
options,
|
||||
providers=providers,
|
||||
)
|
||||
_, height, _, _ = model.get_inputs()[0].shape
|
||||
self.model_target_size = height
|
||||
self.model = model
|
||||
|
||||
def pil_to_cv2_numpy(self, image):
|
||||
"""Prepare the image for model input.
|
||||
|
||||
Args:
|
||||
@ -268,16 +307,24 @@ class Tagger:
|
||||
array = array[:, :, [2, 1, 0]]
|
||||
return array
|
||||
|
||||
@overload
|
||||
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
|
||||
|
||||
@overload
|
||||
def tag(
|
||||
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
|
||||
) -> List[Result]: ...
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image: Image.Image | list[Image.Image],
|
||||
image: Union[Input, List[Input]],
|
||||
general_threshold=0.35,
|
||||
character_threshold=0.9,
|
||||
) -> Result | list[Result]:
|
||||
"""Tag the image and return the results.
|
||||
|
||||
Args:
|
||||
image (PIL.Image | list[PIL.Image]): Input image or list of images.
|
||||
image (Union[Input, List[Input]]): Input image or list of images to tag.
|
||||
general_threshold (float): Threshold for general tags.
|
||||
character_threshold (float): Threshold for character tags.
|
||||
|
||||
@ -285,21 +332,23 @@ class Tagger:
|
||||
Result | list[Result]: Tagging results.
|
||||
"""
|
||||
started_at = time.time()
|
||||
images = [image] if isinstance(image, Image.Image) else image
|
||||
images = [self.prepare_image(img) for img in images]
|
||||
input_is_list = isinstance(image, list)
|
||||
images = image if isinstance(image, list) else [image]
|
||||
images = [to_pil(img) for img in images]
|
||||
images = [self.pil_to_cv2_numpy(img) for img in images]
|
||||
image_array = np.asarray(images, dtype=np.float32)
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
preds = self.model.run([label_name], {input_name: image_array})[0]
|
||||
results = [
|
||||
Result(pred, self.sep_tags, general_threshold, character_threshold)
|
||||
for pred in preds
|
||||
]
|
||||
results = [Result(pred, self.sep_tags, general_threshold, character_threshold) for pred in preds]
|
||||
duration = time.time() - started_at
|
||||
image_length = len(images)
|
||||
self.logger.info(
|
||||
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
||||
)
|
||||
if not self.slient:
|
||||
self.logger.info(
|
||||
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
||||
)
|
||||
if input_is_list:
|
||||
return results
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user