9 Commits

Author SHA1 Message Date
09d331d8f0 version: v0.9.0 2024-07-10 16:53:47 +09:00
6d1cbc0b2b feat(threading): allow use external rich.console 2024-07-10 16:53:32 +09:00
c910823173 version: v0.8.0 2024-06-25 00:38:26 +09:00
16686ce3c4 feat(result): return all tags 2024-06-25 00:38:14 +09:00
9eafed31dc version: v0.7.0 2024-06-23 21:28:06 +09:00
a4003953e6 feat(input): enable input path or numpy array 2024-06-23 21:27:52 +09:00
69f35eb373 version: v0.6.0 2024-06-23 03:21:06 +09:00
c82f881c85 feat(provider): enable change onnx provider 2024-06-23 03:20:56 +09:00
b9a763468a 📚 docs: add codetime badge 2024-06-19 18:34:17 +09:00
4 changed files with 102 additions and 22 deletions

View File

@ -1,5 +1,7 @@
# wdtagger
[![CodeTime Badge](https://img.shields.io/endpoint?style=social&color=222&url=https%3A%2F%2Fapi.codetime.dev%2Fshield%3Fid%3D2%26project%3Dwdtagger%26in=0)](https://codetime.dev)
`wdtagger` is a simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations.
## Installation

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.5.0"
version = "0.9.0"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"

View File

@ -1,3 +1,4 @@
import numpy as np
import pytest
from PIL import Image
@ -20,3 +21,39 @@ def test_tagger(tagger, image_file):
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path(tagger, image_file):
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger, image_file):
image = Image.open(image_file)
image_np = np.array(image)
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger, image_file):
image = Image.open(image_file)
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
def test_tagger_np_single(tagger, image_file):
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert len(results) == 1
result = results[0]
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"

View File

@ -2,29 +2,32 @@ import logging
import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Sequence, Union
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Environment variables and file paths
HF_TOKEN = os.environ.get(
"HF_TOKEN", ""
) # Token for authentication with HuggingFace API
MODEL_FILENAME = "model.onnx" # ONNX model filename
LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename
Input = Union[np.ndarray, Image.Image, str, Path]
available_providers = rt.get_available_providers()
supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
providers = list(set(available_providers) & set(supported_providers))
def to_pil(img: Input) -> Image.Image:
if isinstance(img, (str, Path)):
return Image.open(img)
elif isinstance(img, np.ndarray):
return Image.fromarray(img)
elif isinstance(img, Image.Image):
return img
else:
raise ValueError("Invalid input type.")
def load_labels(dataframe) -> list[str]:
@ -154,6 +157,15 @@ class Result:
string = [x[0] for x in string]
return ", ".join(string)
@property
def all_tags(self) -> list[str]:
"""Return all tags as a list."""
return [self.rating] + list(self.general_tags) + list(self.character_tags)
@property
def all_tags_string(self) -> str:
return ", ".join(self.all_tags)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@ -174,6 +186,8 @@ class Tagger:
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
providers=None,
console=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -183,17 +197,40 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
self.load_model(
model_repo,
cache_dir,
hf_token,
num_threads=num_threads,
providers=providers,
)
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
self,
model_repo,
cache_dir=None,
hf_token=None,
num_threads: int = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] = None,
):
"""Load the model and tags from the specified repository.
@ -203,7 +240,7 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
with self.console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,
@ -233,7 +270,7 @@ class Tagger:
self.model_target_size = height
self.model = model
def prepare_image(self, image):
def pil_to_cv2_numpy(self, image):
"""Prepare the image for model input.
Args:
@ -270,14 +307,14 @@ class Tagger:
def tag(
self,
image: Image.Image | list[Image.Image],
image: Union[Input, List[Input]],
general_threshold=0.35,
character_threshold=0.9,
) -> Result | list[Result]:
"""Tag the image and return the results.
Args:
image (PIL.Image | list[PIL.Image]): Input image or list of images.
image (Union[Input, List[Input]]): Input image or list of images to tag.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
@ -285,8 +322,10 @@ class Tagger:
Result | list[Result]: Tagging results.
"""
started_at = time.time()
images = [image] if isinstance(image, Image.Image) else image
images = [self.prepare_image(img) for img in images]
input_is_list = isinstance(image, list)
images = image if isinstance(image, list) else [image]
images = [to_pil(img) for img in images]
images = [self.pil_to_cv2_numpy(img) for img in images]
image_array = np.asarray(images, dtype=np.float32)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
@ -300,6 +339,8 @@ class Tagger:
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
if input_is_list:
return results
return results[0] if len(results) == 1 else results