feat(tagger): batch tagger

This commit is contained in:
Jianqi Pan 2024-06-08 00:12:41 +09:00
parent 2b67275105
commit 1f1ab1e3df

View File

@@ -1,4 +1,5 @@
import os
import time
from collections import OrderedDict
import huggingface_hub
@@ -6,6 +7,7 @@ import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
# Access console for rich text and logging
@@ -65,9 +67,7 @@ def load_labels(dataframe) -> list[str]:
class Result:
def __init__(
self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9
):
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
"""Initialize the Result object to store tagging results.
Args:
@@ -80,7 +80,7 @@ class Result:
rating_indexes = sep_tags[1]
general_indexes = sep_tags[2]
character_indexes = sep_tags[3]
labels = list(zip(tag_names, preds[0].astype(float)))
labels = list(zip(tag_names, pred.astype(float)))
# Ratings
ratings_names = [labels[i] for i in rating_indexes]
@@ -129,9 +129,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
@property
def character_tags_string(self) -> str:
@@ -142,8 +140,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
string = ", ".join(string)
return string
return ", ".join(string)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@@ -238,44 +235,41 @@ class Tagger:
Image.BICUBIC,
)
# Convert to numpy array
image_array = np.asarray(padded_image, dtype=np.float32)
# Convert PIL-native RGB to BGR
image_array = image_array[:, :, ::-1]
return np.expand_dims(image_array, axis=0)
return np.asarray(padded_image, dtype=np.float32)
def tag(
    self,
    image: Image.Image | list[Image.Image],
    general_threshold=0.35,
    character_threshold=0.9,
) -> Result | list[Result]:
    """Tag one image, or a batch of images, with a single model run.

    Args:
        image (PIL.Image | list[PIL.Image]): Input image or list of images.
        general_threshold (float): Threshold for general tags.
        character_threshold (float): Threshold for character tags.

    Returns:
        Result | list[Result]: A single Result when a single image is
            passed; otherwise a list with one Result per input image, in
            input order. An empty batch yields an empty list.
    """
    # Decide the return shape from the *input* type, not from the number of
    # results, so a one-element batch still comes back as a list.
    single_input = isinstance(image, Image.Image)
    images = [image] if single_input else list(image)
    if not images:
        # An empty batch would reach the model as a zero-length array;
        # there is nothing to tag, so skip the model call.
        return []

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments.
    started_at = time.perf_counter()

    # NOTE(review): assumes prepare_image returns equally-shaped arrays,
    # already in the channel order the model expects, so they stack into
    # one batch -- confirm against prepare_image.
    prepared = [self.prepare_image(img) for img in images]
    image_array = np.asarray(prepared, dtype=np.float32)

    input_name = self.model.get_inputs()[0].name
    label_name = self.model.get_outputs()[0].name
    preds = self.model.run([label_name], {input_name: image_array})[0]
    results = [
        Result(pred, self.sep_tags, general_threshold, character_threshold)
        for pred in preds
    ]
    duration = time.perf_counter() - started_at

    # Build the plural suffix outside the f-string: a replacement field
    # spanning several lines of a single-quoted f-string only parses on
    # Python 3.12+.
    image_count = len(images)
    plural = "s" if image_count > 1 else ""
    console.log(f"Tagging {image_count} image{plural} took {duration:.2f} seconds.")

    return results[0] if single_input else results
__all__ = ["Tagger"]
@@ -283,5 +277,5 @@ __all__ = ["Tagger"]
if __name__ == "__main__":
    # Manual smoke test: tag one bundled sample image with the default
    # thresholds and log the result.
    tagger = Tagger()
    # Pillow opens files lazily, so tagging happens inside the `with` block
    # and the file handle is closed as soon as the pixels have been read.
    with Image.open("./tests/images/赤松楓.9d64b955.jpeg") as image:
        result = tagger.tag(image)
    console.log(result)