diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 4d0aa2e..72f8868 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -1,4 +1,5 @@ import os +import time from collections import OrderedDict import huggingface_hub @@ -6,6 +7,7 @@ import numpy as np import onnxruntime as rt import pandas as pd import rich +import rich.live from PIL import Image # Access console for rich text and logging @@ -65,9 +67,7 @@ def load_labels(dataframe) -> list[str]: class Result: - def __init__( - self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9 - ): + def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9): """Initialize the Result object to store tagging results. Args: @@ -80,7 +80,7 @@ class Result: rating_indexes = sep_tags[1] general_indexes = sep_tags[2] character_indexes = sep_tags[3] - labels = list(zip(tag_names, preds[0].astype(float))) + labels = list(zip(tag_names, pred.astype(float))) # Ratings ratings_names = [labels[i] for i in rating_indexes] @@ -129,9 +129,7 @@ class Result: reverse=True, ) string = [x[0] for x in string] - string = ", ".join(string) - - return string + return ", ".join(string) @property def character_tags_string(self) -> str: @@ -142,8 +140,7 @@ class Result: reverse=True, ) string = [x[0] for x in string] - string = ", ".join(string) - return string + return ", ".join(string) def __str__(self) -> str: """Return a formatted string representation of the tags and their ratings.""" @@ -238,44 +235,41 @@ class Tagger: Image.BICUBIC, ) - # Convert to numpy array - image_array = np.asarray(padded_image, dtype=np.float32) - - # Convert PIL-native RGB to BGR - image_array = image_array[:, :, ::-1] - - return np.expand_dims(image_array, axis=0) + return np.asarray(padded_image, dtype=np.float32) def tag( self, - image, + image: Image.Image | list[Image.Image], general_threshold=0.35, character_threshold=0.9, - ): + ) -> Result | list[Result]: """Tag the image and return the results. Args: - image (PIL.Image): Input image. + image (PIL.Image | list[PIL.Image]): Input image or list of images. general_threshold (float): Threshold for general tags. character_threshold (float): Threshold for character tags. Returns: - Result: Object containing the tagging results. + Result | list[Result]: Tagging results. """ - with console.status("Tagging..."): - image = self.prepare_image(image) - image_array = np.asarray(image, dtype=np.float32) - image_array = image_array[:, :, ::-1] # Convert PIL-native RGB to BGR - - image_array = np.expand_dims(image_array, axis=0) - input_name = self.model.get_inputs()[0].name - label_name = self.model.get_outputs()[0].name - - preds = self.model.run([label_name], {input_name: image_array[0]})[0] - result = Result( - preds, self.sep_tags, general_threshold, character_threshold - ) - return result + started_at = time.time() + images = [image] if isinstance(image, Image.Image) else image + images = [self.prepare_image(img) for img in images] + image_array = np.asarray(images, dtype=np.float32) + input_name = self.model.get_inputs()[0].name + label_name = self.model.get_outputs()[0].name + preds = self.model.run([label_name], {input_name: image_array})[0] + results = [ + Result(pred, self.sep_tags, general_threshold, character_threshold) + for pred in preds + ] + duration = time.time() - started_at + image_length = len(images) + console.log(f"Tagging {image_length} image{ + 's' if image_length > 1 else '' + } took {duration:.2f} seconds.") + return results[0] if len(results) == 1 else results __all__ = ["Tagger"] @@ -283,5 +277,5 @@ __all__ = ["Tagger"] if __name__ == "__main__": tagger = Tagger() image = Image.open("./tests/images/赤松楓.9d64b955.jpeg") - result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35) - print(result) + result = tagger.tag(image) + console.log(result)