Compare commits
No commits in common. "main" and "v0.10.1" have entirely different histories.
3
.gitignore
vendored
3
.gitignore
vendored
@ -2,5 +2,4 @@ dist
|
||||
models
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
.venv
|
||||
.ruff_cache
|
@ -1 +0,0 @@
|
||||
3.12
|
9
.vscode/settings.json
vendored
Normal file
9
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.formatOnType": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
},
|
||||
},
|
||||
}
|
1120
poetry.lock
generated
Normal file
1120
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
108
pyproject.toml
108
pyproject.toml
@ -1,92 +1,36 @@
|
||||
[project]
|
||||
[tool.poetry]
|
||||
name = "wdtagger"
|
||||
version = "0.14.0"
|
||||
description = "A simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations."
|
||||
authors = [{ name = "Jianqi Pan", email = "jannchie@gmail.com" }]
|
||||
version = "0.10.1"
|
||||
description = ""
|
||||
authors = ["Jianqi Pan <jannchie@gmail.com>"]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["huggingface-hub>=0.26.2", "pandas>=2.2.3", "pillow>=11.0.0"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pytest>=8.3.3", "pytest-benchmark>=5.1.0", "ruff>=0.8.0"]
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
pillow = "^10.3.0"
|
||||
pandas = "^2.2.2"
|
||||
huggingface-hub = "^0.23.3"
|
||||
rich = "^13.7.1"
|
||||
onnxruntime-gpu = "^1.18.0"
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
gpu = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
cuda11 = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
cuda12 = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^8.2.2"
|
||||
pytest-benchmark = "^4.0.0"
|
||||
ruff = "^0.5.5"
|
||||
isort = "^5.13.2"
|
||||
black = "^24.4.2"
|
||||
mypy = "^1.11.0"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 140
|
||||
select = ["ALL"]
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
ignore = [
|
||||
"PGH",
|
||||
"RUF003",
|
||||
"BLE001",
|
||||
"ERA001",
|
||||
"FIX002",
|
||||
"TD002",
|
||||
"TD003",
|
||||
"D",
|
||||
"N812",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # asserts allowed in tests
|
||||
"PLR2004", # Magic value used in comparison
|
||||
]
|
||||
[tool.pyright]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
{ extra = "cuda11" },
|
||||
{ extra = "cuda12" },
|
||||
],
|
||||
]
|
||||
package = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "torch-cpu", extra = "cpu" },
|
||||
{ index = "torch-gpu", extra = "gpu" },
|
||||
{ index = "torch-cuda11", extra = "cuda11" },
|
||||
{ index = "torch-cuda12", extra = "cuda12" },
|
||||
]
|
||||
|
||||
torchvision = [
|
||||
{ index = "torch-cpu", extra = "cpu" },
|
||||
{ index = "torch-gpu", extra = "gpu" },
|
||||
{ index = "torch-cuda11", extra = "cuda11" },
|
||||
{ index = "torch-cuda12", extra = "cuda12" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-gpu"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-cuda11"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-cuda12"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -1,326 +0,0 @@
|
||||
import importlib.resources
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, overload
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import timm
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageFile import ImageFile
|
||||
from timm.data import create_transform, resolve_data_config
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
||||
|
||||
|
||||
Input = np.ndarray | Image.Image | str | Path | ImageFile
|
||||
|
||||
|
||||
@dataclass
class LabelData:
    """Tag names plus, for each tag category, the row positions of its tags."""

    # Tag names, in the row order of the label table.
    names: list[str]
    # Row positions of the rating tags.
    rating: list[np.int64]
    # Row positions of the general tags.
    general: list[np.int64]
    # Row positions of the character tags.
    character: list[np.int64]
|
||||
|
||||
|
||||
def load_labels() -> LabelData:
    """Read the packaged ``selected_tags.csv`` and group its rows by category.

    Returns:
        LabelData: Every tag name, plus the row positions of the rating
        (category 9), general (category 0) and character (category 4) tags.
    """
    resource = importlib.resources.files("wdtagger.assets").joinpath("selected_tags.csv")
    with importlib.resources.as_file(resource) as csv_path:
        table: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])

    categories = table["category"]

    def rows_in(category_id: int) -> list:
        # np.where on a 1-D mask returns a one-element tuple of index arrays.
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=table["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
|
||||
|
||||
|
||||
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any alpha onto a white canvas."""
    # Palette/greyscale/etc. images: keep an alpha channel only when the file
    # advertises transparency.
    if image.mode not in ["RGB", "RGBA"]:
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Composite the RGBA picture over white, then drop the alpha channel.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
|
||||
|
||||
|
||||
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre ``image`` on a white square whose side equals its longer edge."""
    width, height = image.size
    side = max(image.size)
    top_left = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, top_left)
    return square
|
||||
|
||||
|
||||
def to_pil(img: Input) -> Image.Image:
    """Coerce a supported input into a PIL image.

    Args:
        img: A PIL image (returned as-is), a NumPy array (wrapped with
            ``Image.fromarray``) or a ``str``/``Path`` (opened with ``Image.open``).

    Raises:
        ValueError: If ``img`` is none of the supported types.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    if isinstance(img, str | Path):
        return Image.open(img)
    msg = "Invalid input type."
    raise ValueError(msg)
|
||||
|
||||
|
||||
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
) -> tuple[dict[str, float], dict[str, float], dict[str, float]]:
    """Turn one row of model probabilities into per-category tag scores.

    Args:
        probs: Probabilities for a single image, aligned with ``labels.names``.
        labels: Tag names and the index lists of each category.
        gen_threshold: A general tag is kept only if its score is above this.
        char_threshold: A character tag is kept only if its score is above this.

    Returns:
        ``(rating_labels, char_labels, gen_labels)``. Ratings are unfiltered;
        the other two are filtered and ordered by descending score.
    """
    # Pair every tag name with its probability.
    scored = list(zip(labels.names, probs.numpy(), strict=False))

    def above(indices: list, threshold: float) -> dict:
        # Keep the entries scoring strictly above the threshold, best first.
        kept = {}
        for idx in indices:
            name, score = scored[idx]
            if score > threshold:
                kept[name] = score
        return dict(sorted(kept.items(), key=lambda pair: pair[1], reverse=True))

    # Ratings are returned for every rating index, with no threshold applied.
    ratings = dict(scored[idx] for idx in labels.rating)
    general = above(labels.general, gen_threshold)
    characters = above(labels.character, char_threshold)

    return ratings, characters, general
|
||||
|
||||
|
||||
class Result:
    """Tagging outcome for one image: rating scores plus selected tags."""

    def __init__(
        self,
        rating_data: dict[str, float],
        character_tag: dict[str, float],
        general_tag: dict[str, float],
    ) -> None:
        """Keep the three name-to-score mappings produced for an image."""
        self.rating_data = rating_data
        self.character_tag_data = character_tag
        self.general_tag_data = general_tag

    @staticmethod
    def _ranked(scores: dict[str, float]) -> list[str]:
        """Return the keys of ``scores`` ordered by descending score."""
        return sorted(scores, key=lambda name: scores[name], reverse=True)

    @property
    def general_tags(self) -> tuple[str, ...]:
        """General tag names, best first, with underscores shown as spaces."""
        ranked = self._ranked(self.general_tag_data)
        return tuple(name.replace("_", " ") for name in ranked)

    @property
    def character_tags(self) -> tuple[str, ...]:
        """Character tag names, best first, with underscores shown as spaces."""
        ranked = self._ranked(self.character_tag_data)
        return tuple(name.replace("_", " ") for name in ranked)

    @property
    def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
        """The rating name carrying the highest score."""
        scores = self.rating_data
        return max(scores, key=lambda name: scores[name])  # type: ignore

    @property
    def general_tags_string(self) -> str:
        """General tag names (underscores kept), best first, comma separated."""
        return ", ".join(self._ranked(self.general_tag_data))

    @property
    def character_tags_string(self) -> str:
        """Character tag names (underscores kept), best first, comma separated."""
        return ", ".join(self._ranked(self.character_tag_data))

    @property
    def all_tags(self) -> list[str]:
        """The rating followed by the character tags and then the general tags."""
        return [self.rating, *self.character_tags, *self.general_tags]

    @property
    def all_tags_string(self) -> str:
        """``all_tags`` joined with ``", "``."""
        return ", ".join(self.all_tags)

    def __str__(self) -> str:
        """Render the general tags, character tags and top rating with scores."""

        def describe(scores: dict[str, float]) -> str:
            # Entries appear in the mapping's own order, each as "name (0.00)".
            return ", ".join(f"{name} ({score:.2f})" for name, score in scores.items())

        top = self.rating
        lines = [
            f"General tags: {describe(self.general_tag_data)}",
            f"Character tags: {describe(self.character_tag_data)}",
            f"Rating: {top} ({self.rating_data[top]:.2f})",
        ]
        return "\n".join(lines)
|
||||
|
||||
|
||||
class Tagger:
    """Tag images with a timm model loaded from the HuggingFace Hub."""

    def __init__(
        self,
        model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
        hf_token: str = HF_TOKEN,
    ) -> None:
        """Pick a device (CUDA when available, else CPU) and load the model.

        Args:
            model_repo (str): Repository name on HuggingFace.
            hf_token (str): HuggingFace token; stored on the instance as
                ``hf_token``. Defaults to the ``HF_TOKEN`` module constant.
        """
        self.logger = logging.getLogger("wdtagger")
        self.model_target_size = None
        self.hf_token = hf_token
        self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model(model_repo)

    def _load_model(
        self,
        model_repo: str,
    ) -> None:
        """Load the model and labels, logging how long loading took.

        Args:
            model_repo (str): Repository name on HuggingFace.
        """
        start_time = time.time()
        self.logger.info("Loading model from %s", model_repo)
        self._do_load_model(model_repo)
        self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time)

    def _do_load_model(self, model_repo: str) -> None:
        """Build the model, its labels and its preprocessing transform.

        Sets ``self.labels``, ``self.transform`` and ``self.model`` (the model
        in eval mode, moved to ``self.torch_device``).

        Args:
            model_repo (str): Repository name on HuggingFace.
        """
        model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        model.load_state_dict(state_dict)

        self.labels: LabelData = load_labels()

        # The preprocessing pipeline is derived from the model's own pretrained config.
        self.transform: Compose = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))  # type: ignore

        self.model = model.to(self.torch_device)

    @overload
    def tag(
        self,
        image: Input,
        general_threshold: float = 0.35,
        character_threshold: float = 0.9,
    ) -> Result: ...

    @overload
    def tag(
        self,
        image: Sequence[Input],
        general_threshold: float = 0.35,
        character_threshold: float = 0.9,
    ) -> Sequence[Result]: ...

    def tag(
        self,
        image: Input | Sequence[Input],
        general_threshold: float = 0.35,
        character_threshold: float = 0.9,
    ) -> Result | Sequence[Result]:
        """Tag the image and return the results.

        Args:
            image (Input | Sequence[Input]): One image or a sequence of images.
                A ``str`` is always treated as a single path, never as a sequence.
            general_threshold (float): Threshold for general tags.
            character_threshold (float): Threshold for character tags.

        Returns:
            Result | Sequence[Result]: A single ``Result`` for a single input,
            or a list with one ``Result`` per image for a sequence input
            (an empty list for an empty sequence).
        """
        started_at = time.time()
        is_batch = isinstance(image, Sequence) and not isinstance(image, str)
        images = list(image) if is_batch else [image]
        if not images:
            # torch.stack() raises on an empty list; an empty batch simply has no results.
            return []

        images = [to_pil(img) for img in images]
        images = [pil_ensure_rgb(img) for img in images]
        images = [pil_pad_square(img) for img in images]
        inputs: Tensor = torch.stack([self.transform(img) for img in images])  # type: ignore
        # Reverse the channel axis. The images were converted to RGB above, so
        # this hands the model BGR-ordered channels (assuming the transform
        # keeps the channel order).
        inputs = inputs[:, [2, 1, 0]]

        with torch.inference_mode():
            # Move the batch to the model's device when that is not the CPU.
            if self.torch_device.type != "cpu":
                inputs = inputs.to(self.torch_device)
            # run the model
            outputs = self.model.forward(inputs)
            # apply the final activation function (timm doesn't support doing this internally)
            outputs = F.sigmoid(outputs)
            # Only the outputs are needed afterwards, so only they go back to the CPU.
            if self.torch_device.type != "cpu":
                outputs = outputs.to("cpu")

        results = [
            Result(
                *get_tags(
                    probs=o,
                    labels=self.labels,
                    gen_threshold=general_threshold,
                    char_threshold=character_threshold,
                ),
            )
            for o in outputs
        ]

        duration = time.time() - started_at
        image_length = len(images)
        self.logger.info(
            "Tagged %d image%s in %.2fs",
            image_length,
            "s" if image_length > 1 else "",
            duration,
        )
        # A non-sequence input always yields exactly one result.
        return results if is_batch else results[0]
|
||||
|
||||
|
||||
__all__ = ["Tagger"]
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run: tag the sample test image and log all of its tags.
    logging.basicConfig(level=logging.INFO)
    demo_log = logging.getLogger("wdtagger")
    demo_tagger = Tagger()
    sample_batch = [Image.open("./tests/images/赤松楓.9d64b955.jpeg")]
    for tagged in demo_tagger.tag(sample_batch):
        demo_log.info(tagged.all_tags_string)
|
File diff suppressed because it is too large
Load Diff
@ -6,60 +6,54 @@ from wdtagger import Tagger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tagger() -> Tagger:
|
||||
"""
|
||||
Create and return a new instance of the Tagger class.
|
||||
|
||||
Returns:
|
||||
Tagger: An instance of the Tagger class.
|
||||
"""
|
||||
def tagger():
|
||||
return Tagger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_file() -> str:
|
||||
def image_file():
|
||||
return "./tests/images/赤松楓.9d64b955.jpeg"
|
||||
|
||||
|
||||
def test_tagger(tagger: Tagger, image_file: str) -> None:
|
||||
def test_tagger(tagger, image_file):
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path_single(tagger: Tagger, image_file: str) -> None:
|
||||
def test_tagger_path(tagger, image_file):
|
||||
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_np(tagger: Tagger, image_file: str) -> None:
|
||||
def test_tagger_np(tagger, image_file):
|
||||
image = Image.open(image_file)
|
||||
image_np = np.array(image)
|
||||
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_pil(tagger: Tagger, image_file: str) -> None:
|
||||
def test_tagger_pil(tagger, image_file):
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None:
|
||||
results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 2
|
||||
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
|
||||
def test_tagger_np_single(tagger, image_file):
|
||||
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
361
wdtagger/__init__.py
Normal file
361
wdtagger/__init__.py
Normal file
@ -0,0 +1,361 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Sequence, Union, overload
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import onnxruntime as rt
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from rich.logging import RichHandler
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
||||
MODEL_FILENAME = "model.onnx"
|
||||
LABEL_FILENAME = "selected_tags.csv"
|
||||
|
||||
Input = Union[np.ndarray, Image.Image, str, Path]
|
||||
|
||||
|
||||
def to_pil(img: Input) -> Image.Image:
    """Coerce a supported input into a PIL image.

    Args:
        img: A PIL image (returned as-is), a NumPy array (wrapped with
            ``Image.fromarray``) or a ``str``/``Path`` (opened with ``Image.open``).

    Raises:
        ValueError: If ``img`` is none of the supported types.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    if isinstance(img, (str, Path)):
        return Image.open(img)
    raise ValueError("Invalid input type.")
|
||||
|
||||
|
||||
def load_labels(dataframe) -> tuple[Any, list[Any], list[Any], list[Any]]:
    """Extract display names and per-category row positions from a tag table.

    Underscores in tag names are replaced by spaces, except for the kaomoji
    names listed below, which are kept verbatim.

    Args:
        dataframe (pd.DataFrame): Table with a ``name`` and a ``category`` column.

    Returns:
        tag_names: List of (prettified) tag names.
        rating_indexes: Row positions whose category is 9.
        general_indexes: Row positions whose category is 0.
        character_indexes: Row positions whose category is 4.
    """
    keep_verbatim = (
        "0_0",
        "(o)_(o)",
        "+_+",
        "+_-",
        "._.",
        "<o>_<o>",
        "<|>_<|>",
        "=_=",
        ">_<",
        "3_3",
        "6_9",
        ">_o",
        "@_@",
        "^_^",
        "o_o",
        "u_u",
        "x_x",
        "|_|",
        "||_||",
    )

    def prettify(name):
        # Kaomoji rely on their underscores, so they pass through untouched.
        if name in keep_verbatim:
            return name
        return name.replace("_", " ")

    tag_names = dataframe["name"].map(prettify).tolist()

    categories = dataframe["category"]
    rating_indexes, general_indexes, character_indexes = (
        list(np.where(categories == category_id)[0]) for category_id in (9, 0, 4)
    )

    return tag_names, rating_indexes, general_indexes, character_indexes
|
||||
|
||||
|
||||
class Result:
    """Tagging outcome for one image: rating scores plus thresholded tags."""

    def __init__(
        self,
        pred,
        sep_tags,
        general_threshold=0.35,
        character_threshold=0.9,
    ):
        """Initialize the Result object to store tagging results.

        Args:
            pred (np.array): Prediction scores for one image, aligned with the tag names.
            sep_tags (tuple): ``(tag_names, rating_indexes, general_indexes,
                character_indexes)`` as returned by ``load_labels``.
            general_threshold (float): A general tag is kept only if its score is above this.
            character_threshold (float): A character tag is kept only if its score is above this.
        """
        tag_names = sep_tags[0]
        rating_indexes = sep_tags[1]
        general_indexes = sep_tags[2]
        character_indexes = sep_tags[3]
        # Pair each tag name with its score as a plain Python/NumPy float64.
        labels = list(zip(tag_names, pred.astype(float)))

        # Ratings are never thresholded; duplicates are collapsed before ranking.
        rating_pairs = dict(labels[i] for i in rating_indexes).items()
        self.rating_data = self._ranked(rating_pairs)
        self.general_tag_data = self._ranked(
            pair for pair in (labels[i] for i in general_indexes) if pair[1] > general_threshold
        )
        self.character_tag_data = self._ranked(
            pair for pair in (labels[i] for i in character_indexes) if pair[1] > character_threshold
        )

    @staticmethod
    def _ranked(pairs) -> "OrderedDict[str, float]":
        """Return ``(name, score)`` pairs as a mapping ordered by descending score."""
        return OrderedDict(sorted(pairs, key=lambda pair: pair[1], reverse=True))

    @property
    def general_tags(self) -> tuple[str, ...]:
        """Return general tags as a tuple."""
        return tuple(self.general_tag_data.keys())

    @property
    def character_tags(self) -> tuple[str, ...]:
        """Return character tags as a tuple."""
        return tuple(self.character_tag_data.keys())

    @property
    def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
        """Return the highest rated tag."""
        return max(self.rating_data, key=self.rating_data.get)

    @property
    def general_tags_string(self) -> str:
        """Return general tags as a sorted string."""
        # Re-sorted on access so later edits to general_tag_data are still honoured.
        ordered = sorted(self.general_tag_data.items(), key=lambda x: x[1], reverse=True)
        return ", ".join(name for name, _ in ordered)

    @property
    def character_tags_string(self) -> str:
        """Return character tags as a sorted string."""
        # Re-sorted on access so later edits to character_tag_data are still honoured.
        ordered = sorted(self.character_tag_data.items(), key=lambda x: x[1], reverse=True)
        return ", ".join(name for name, _ in ordered)

    @property
    def all_tags(self) -> list[str]:
        """Return the rating, then the general tags, then the character tags."""
        return [self.rating, *self.general_tags, *self.character_tags]

    @property
    def all_tags_string(self) -> str:
        """Return ``all_tags`` joined with ``", "``."""
        return ", ".join(self.all_tags)

    def __str__(self) -> str:
        """Return a formatted string representation of the tags and their ratings."""

        def get_tag_with_rate(tag_dict):
            return ", ".join([f"{k} ({v:.2f})" for k, v in tag_dict.items()])

        result = f"General tags: {get_tag_with_rate(self.general_tag_data)}\n"
        result += f"Character tags: {get_tag_with_rate(self.character_tag_data)}\n"
        result += f"Rating: {self.rating} ({self.rating_data[self.rating]:.2f})"
        return result
|
||||
|
||||
|
||||
class Tagger:
    """Tag images with an ONNX tagger model downloaded from the HuggingFace Hub."""

    def __init__(
        self,
        model_repo="SmilingWolf/wd-swinv2-tagger-v3",
        cache_dir=None,
        hf_token=HF_TOKEN,
        loglevel=logging.INFO,
        num_threads=None,
        providers=None,
        console=None,
        slient=False,
    ):
        """Initialize the Tagger object with the model repository and tokens.

        Args:
            model_repo (str): Repository name on HuggingFace.
            cache_dir (str, optional): Directory to cache the model. Defaults to None.
            hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
            loglevel (int, optional): Logging level. Defaults to logging.INFO.
            num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
            providers (list, optional): List of providers for ONNX runtime. Defaults to
                CUDA first, then CPU.
            console (rich.console.Console, optional): Rich console object. Defaults to None.
            slient (bool, optional): When true, skip the console status display and
                the per-call timing log. Defaults to False.
        """
        self.slient = slient

        if providers is None:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if not slient:
            if not console:
                from rich import get_console

                self.console = get_console()
            else:
                self.console = console
        self.logger = logging.getLogger("wdtagger")
        self.logger.setLevel(loglevel)
        # "wdtagger" is one shared logger: attaching a new handler for every
        # Tagger instance would print each message once per instance.
        if not any(isinstance(handler, RichHandler) for handler in self.logger.handlers):
            self.logger.addHandler(RichHandler())
        self.model_target_size = None
        self.cache_dir = cache_dir
        self.hf_token = hf_token
        self.load_model(
            model_repo,
            cache_dir,
            hf_token,
            num_threads=num_threads,
            providers=providers,
        )

    def load_model(
        self,
        model_repo,
        cache_dir=None,
        hf_token=None,
        num_threads: int | None = None,
        providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
    ):
        """Load the model and tags from the specified repository.

        Shows a console status line while loading unless the tagger is silent.

        Args:
            model_repo (str): Repository name on HuggingFace.
            cache_dir (str, optional): Directory to cache the model. Defaults to None.
            hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
            num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
            providers (Sequence, optional): Execution providers for ONNX runtime. Defaults to None.
        """
        if not self.slient:
            with self.console.status("Loading model..."):
                self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)
        else:
            self.do_load_model(model_repo, cache_dir, hf_token, num_threads, providers)

    def do_load_model(self, model_repo, cache_dir, hf_token, num_threads, providers):
        """Download the label CSV and ONNX model, then create the inference session.

        Sets ``self.sep_tags``, ``self.model_target_size`` and ``self.model``.
        """
        # An empty token (the HF_TOKEN default when the variable is unset) is
        # treated as "no explicit token" rather than sent as a credential.
        token = hf_token or None

        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            cache_dir=cache_dir,
            token=token,
        )

        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            cache_dir=cache_dir,
            token=token,
        )

        tags_df = pd.read_csv(csv_path)
        self.sep_tags = load_labels(tags_df)
        options = rt.SessionOptions()
        if num_threads:
            options.intra_op_num_threads = num_threads
            options.inter_op_num_threads = num_threads
        model = rt.InferenceSession(
            model_path,
            options,
            providers=providers,
        )
        # The second dimension of the first model input is used as the (square) side length.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height
        self.model = model

    def pil_to_cv2_numpy(self, image):
        """Prepare the image for model input.

        Flattens transparency onto white, pads to a white square, resizes to
        the model's target size and reverses the channel order.

        Args:
            image (PIL.Image): Input image.

        Returns:
            np.array: Processed image as a float32 NumPy array.
        """
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        target_size = self.model_target_size
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        array = np.asarray(padded_image, dtype=np.float32)
        # Reverse the last axis: RGB -> BGR.
        array = array[:, :, [2, 1, 0]]
        return array

    @overload
    def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...

    @overload
    def tag(
        self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
    ) -> List[Result]: ...

    def tag(
        self,
        image: Union[Input, List[Input]],
        general_threshold=0.35,
        character_threshold=0.9,
    ) -> Result | list[Result]:
        """Tag the image and return the results.

        Args:
            image (Union[Input, List[Input]]): Input image or list (or tuple) of images to tag.
            general_threshold (float): Threshold for general tags.
            character_threshold (float): Threshold for character tags.

        Returns:
            Result | list[Result]: A single ``Result`` for a single input, or a
            list with one ``Result`` per image for a list/tuple input (an empty
            list for an empty input).
        """
        started_at = time.time()
        input_is_list = isinstance(image, (list, tuple))
        images = list(image) if input_is_list else [image]
        if not images:
            # Nothing to run: an empty batch has no usable shape for the session.
            return []
        images = [to_pil(img) for img in images]
        images = [self.pil_to_cv2_numpy(img) for img in images]
        image_array = np.asarray(images, dtype=np.float32)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image_array})[0]
        results = [Result(pred, self.sep_tags, general_threshold, character_threshold) for pred in preds]
        duration = time.time() - started_at
        image_length = len(images)
        if not self.slient:
            # Lazy %-formatting: the message is only built if the record is emitted.
            self.logger.info(
                "Tagging %d image%s took %.2f seconds.",
                image_length,
                "s" if image_length > 1 else "",
                duration,
            )
        # A non-list input always yields exactly one result.
        return results if input_is_list else results[0]
|
||||
|
||||
|
||||
__all__ = ["Tagger"]
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run: tag the sample test image and log the result.
    demo_tagger = Tagger()
    sample = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
    tagged = demo_tagger.tag(sample)
    demo_tagger.logger.info(tagged)
|
Loading…
x
Reference in New Issue
Block a user