|
|
|
@ -1,4 +1,6 @@
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import huggingface_hub
|
|
|
|
@ -6,7 +8,9 @@ import numpy as np
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import rich
|
|
|
|
|
import rich.live
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from rich.logging import RichHandler
|
|
|
|
|
|
|
|
|
|
# Access console for rich text and logging
|
|
|
|
|
console = rich.get_console()
|
|
|
|
@ -66,7 +70,11 @@ def load_labels(dataframe) -> list[str]:
|
|
|
|
|
|
|
|
|
|
class Result:
|
|
|
|
|
def __init__(
|
|
|
|
|
self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9
|
|
|
|
|
self,
|
|
|
|
|
pred,
|
|
|
|
|
sep_tags,
|
|
|
|
|
general_threshold=0.35,
|
|
|
|
|
character_threshold=0.9,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Result object to store tagging results.
|
|
|
|
|
|
|
|
|
@ -80,7 +88,7 @@ class Result:
|
|
|
|
|
rating_indexes = sep_tags[1]
|
|
|
|
|
general_indexes = sep_tags[2]
|
|
|
|
|
character_indexes = sep_tags[3]
|
|
|
|
|
labels = list(zip(tag_names, preds[0].astype(float)))
|
|
|
|
|
labels = list(zip(tag_names, pred.astype(float)))
|
|
|
|
|
|
|
|
|
|
# Ratings
|
|
|
|
|
ratings_names = [labels[i] for i in rating_indexes]
|
|
|
|
@ -129,9 +137,7 @@ class Result:
|
|
|
|
|
reverse=True,
|
|
|
|
|
)
|
|
|
|
|
string = [x[0] for x in string]
|
|
|
|
|
string = ", ".join(string)
|
|
|
|
|
|
|
|
|
|
return string
|
|
|
|
|
return ", ".join(string)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def character_tags_string(self) -> str:
|
|
|
|
@ -142,8 +148,7 @@ class Result:
|
|
|
|
|
reverse=True,
|
|
|
|
|
)
|
|
|
|
|
string = [x[0] for x in string]
|
|
|
|
|
string = ", ".join(string)
|
|
|
|
|
return string
|
|
|
|
|
return ", ".join(string)
|
|
|
|
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
|
|
"""Return a formatted string representation of the tags and their ratings."""
|
|
|
|
@ -163,6 +168,7 @@ class Tagger:
|
|
|
|
|
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
|
|
|
|
|
cache_dir=None,
|
|
|
|
|
hf_token=HF_TOKEN,
|
|
|
|
|
loglevel=logging.INFO,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Tagger object with the model repository and tokens.
|
|
|
|
|
|
|
|
|
@ -171,12 +177,17 @@ class Tagger:
|
|
|
|
|
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
|
|
|
|
"""
|
|
|
|
|
self.logger = logging.getLogger("wdtagger")
|
|
|
|
|
self.logger.setLevel(loglevel)
|
|
|
|
|
self.logger.addHandler(RichHandler())
|
|
|
|
|
self.model_target_size = None
|
|
|
|
|
self.cache_dir = cache_dir
|
|
|
|
|
self.hf_token = hf_token
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token)
|
|
|
|
|
|
|
|
|
|
def load_model(self, model_repo, cache_dir=None, hf_token=None):
|
|
|
|
|
def load_model(
|
|
|
|
|
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
|
|
|
|
):
|
|
|
|
|
"""Load the model and tags from the specified repository.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -191,6 +202,7 @@ class Tagger:
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
use_auth_token=hf_token,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
model_path = huggingface_hub.hf_hub_download(
|
|
|
|
|
model_repo,
|
|
|
|
|
MODEL_FILENAME,
|
|
|
|
@ -200,8 +212,11 @@ class Tagger:
|
|
|
|
|
|
|
|
|
|
tags_df = pd.read_csv(csv_path)
|
|
|
|
|
self.sep_tags = load_labels(tags_df)
|
|
|
|
|
|
|
|
|
|
model = rt.InferenceSession(model_path)
|
|
|
|
|
options = rt.SessionOptions()
|
|
|
|
|
if num_threads:
|
|
|
|
|
options.intra_op_num_threads = num_threads
|
|
|
|
|
options.inter_op_num_threads = num_threads
|
|
|
|
|
model = rt.InferenceSession(model_path, options)
|
|
|
|
|
_, height, _, _ = model.get_inputs()[0].shape
|
|
|
|
|
self.model_target_size = height
|
|
|
|
|
self.model = model
|
|
|
|
@ -238,44 +253,41 @@ class Tagger:
|
|
|
|
|
Image.BICUBIC,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Convert to numpy array
|
|
|
|
|
image_array = np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
# Convert PIL-native RGB to BGR
|
|
|
|
|
image_array = image_array[:, :, ::-1]
|
|
|
|
|
|
|
|
|
|
return np.expand_dims(image_array, axis=0)
|
|
|
|
|
return np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
def tag(
|
|
|
|
|
self,
|
|
|
|
|
image,
|
|
|
|
|
image: Image.Image | list[Image.Image],
|
|
|
|
|
general_threshold=0.35,
|
|
|
|
|
character_threshold=0.9,
|
|
|
|
|
):
|
|
|
|
|
) -> Result | list[Result]:
|
|
|
|
|
"""Tag the image and return the results.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
image (PIL.Image): Input image.
|
|
|
|
|
image (PIL.Image | list[PIL.Image]): Input image or list of images.
|
|
|
|
|
general_threshold (float): Threshold for general tags.
|
|
|
|
|
character_threshold (float): Threshold for character tags.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Result: Object containing the tagging results.
|
|
|
|
|
Result | list[Result]: Tagging results.
|
|
|
|
|
"""
|
|
|
|
|
with console.status("Tagging..."):
|
|
|
|
|
image = self.prepare_image(image)
|
|
|
|
|
image_array = np.asarray(image, dtype=np.float32)
|
|
|
|
|
image_array = image_array[:, :, ::-1] # Convert PIL-native RGB to BGR
|
|
|
|
|
|
|
|
|
|
image_array = np.expand_dims(image_array, axis=0)
|
|
|
|
|
input_name = self.model.get_inputs()[0].name
|
|
|
|
|
label_name = self.model.get_outputs()[0].name
|
|
|
|
|
|
|
|
|
|
preds = self.model.run([label_name], {input_name: image_array[0]})[0]
|
|
|
|
|
result = Result(
|
|
|
|
|
preds, self.sep_tags, general_threshold, character_threshold
|
|
|
|
|
)
|
|
|
|
|
return result
|
|
|
|
|
started_at = time.time()
|
|
|
|
|
images = [image] if isinstance(image, Image.Image) else image
|
|
|
|
|
images = [self.prepare_image(img) for img in images]
|
|
|
|
|
image_array = np.asarray(images, dtype=np.float32)
|
|
|
|
|
input_name = self.model.get_inputs()[0].name
|
|
|
|
|
label_name = self.model.get_outputs()[0].name
|
|
|
|
|
preds = self.model.run([label_name], {input_name: image_array})[0]
|
|
|
|
|
results = [
|
|
|
|
|
Result(pred, self.sep_tags, general_threshold, character_threshold)
|
|
|
|
|
for pred in preds
|
|
|
|
|
]
|
|
|
|
|
duration = time.time() - started_at
|
|
|
|
|
image_length = len(images)
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
|
|
|
|
)
|
|
|
|
|
return results[0] if len(results) == 1 else results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["Tagger"]
|
|
|
|
@ -283,5 +295,5 @@ __all__ = ["Tagger"]
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
tagger = Tagger()
|
|
|
|
|
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
|
|
|
|
|
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
|
|
|
|
print(result)
|
|
|
|
|
result = tagger.tag(image)
|
|
|
|
|
tagger.logger.info(result)
|
|
|
|
|