4 Commits

Author SHA1 Message Date
ae2838f58e version: v0.3.0 2024-06-11 03:18:40 +09:00
f9ec9de157 feat(logger): use logger 2024-06-11 03:18:23 +09:00
ea88303c8a version: v0.2.0 2024-06-08 00:13:16 +09:00
1f1ab1e3df feat(tagger): batch tagger 2024-06-08 00:12:41 +09:00
3 changed files with 59 additions and 38 deletions

9
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,9 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnType": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
}

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.1.0"
version = "0.3.0"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"

View File

@@ -1,4 +1,6 @@
import logging
import os
import time
from collections import OrderedDict
import huggingface_hub
@@ -6,7 +8,9 @@ import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
@@ -66,7 +70,11 @@ def load_labels(dataframe) -> list[str]:
class Result:
def __init__(
self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9
self,
pred,
sep_tags,
general_threshold=0.35,
character_threshold=0.9,
):
"""Initialize the Result object to store tagging results.
@@ -80,7 +88,7 @@ class Result:
rating_indexes = sep_tags[1]
general_indexes = sep_tags[2]
character_indexes = sep_tags[3]
labels = list(zip(tag_names, preds[0].astype(float)))
labels = list(zip(tag_names, pred.astype(float)))
# Ratings
ratings_names = [labels[i] for i in rating_indexes]
@@ -129,9 +137,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
@property
def character_tags_string(self) -> str:
@@ -142,8 +148,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@@ -163,6 +168,7 @@ class Tagger:
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
):
"""Initialize the Tagger object with the model repository and tokens.
@@ -171,12 +177,17 @@ class Tagger:
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
"""
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token)
def load_model(self, model_repo, cache_dir=None, hf_token=None):
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
):
"""Load the model and tags from the specified repository.
Args:
@@ -191,6 +202,7 @@ class Tagger:
cache_dir=cache_dir,
use_auth_token=hf_token,
)
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
@@ -200,8 +212,11 @@ class Tagger:
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
model = rt.InferenceSession(model_path)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(model_path, options)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
@@ -238,44 +253,41 @@ class Tagger:
Image.BICUBIC,
)
# Convert to numpy array
image_array = np.asarray(padded_image, dtype=np.float32)
# Convert PIL-native RGB to BGR
image_array = image_array[:, :, ::-1]
return np.expand_dims(image_array, axis=0)
return np.asarray(padded_image, dtype=np.float32)
def tag(
self,
image,
image: Image.Image | list[Image.Image],
general_threshold=0.35,
character_threshold=0.9,
):
) -> Result | list[Result]:
"""Tag the image and return the results.
Args:
image (PIL.Image): Input image.
image (PIL.Image | list[PIL.Image]): Input image or list of images.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
Returns:
Result: Object containing the tagging results.
Result | list[Result]: Tagging results.
"""
with console.status("Tagging..."):
image = self.prepare_image(image)
image_array = np.asarray(image, dtype=np.float32)
image_array = image_array[:, :, ::-1] # Convert PIL-native RGB to BGR
image_array = np.expand_dims(image_array, axis=0)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image_array[0]})[0]
result = Result(
preds, self.sep_tags, general_threshold, character_threshold
)
return result
started_at = time.time()
images = [image] if isinstance(image, Image.Image) else image
images = [self.prepare_image(img) for img in images]
image_array = np.asarray(images, dtype=np.float32)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image_array})[0]
results = [
Result(pred, self.sep_tags, general_threshold, character_threshold)
for pred in preds
]
duration = time.time() - started_at
image_length = len(images)
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
return results[0] if len(results) == 1 else results
__all__ = ["Tagger"]
@@ -283,5 +295,5 @@ __all__ = ["Tagger"]
if __name__ == "__main__":
tagger = Tagger()
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
print(result)
result = tagger.tag(image)
tagger.logger.info(result)