✨ feat(tagger): batch tagger
This commit is contained in:
parent
2b67275105
commit
1f1ab1e3df
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import huggingface_hub
|
||||
@ -6,6 +7,7 @@ import numpy as np
|
||||
import onnxruntime as rt
|
||||
import pandas as pd
|
||||
import rich
|
||||
import rich.live
|
||||
from PIL import Image
|
||||
|
||||
# Access console for rich text and logging
|
||||
@ -65,9 +67,7 @@ def load_labels(dataframe) -> list[str]:
|
||||
|
||||
|
||||
class Result:
|
||||
def __init__(
|
||||
self, preds, sep_tags, general_threshold=0.35, character_threshold=0.9
|
||||
):
|
||||
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
|
||||
"""Initialize the Result object to store tagging results.
|
||||
|
||||
Args:
|
||||
@ -80,7 +80,7 @@ class Result:
|
||||
rating_indexes = sep_tags[1]
|
||||
general_indexes = sep_tags[2]
|
||||
character_indexes = sep_tags[3]
|
||||
labels = list(zip(tag_names, preds[0].astype(float)))
|
||||
labels = list(zip(tag_names, pred.astype(float)))
|
||||
|
||||
# Ratings
|
||||
ratings_names = [labels[i] for i in rating_indexes]
|
||||
@ -129,9 +129,7 @@ class Result:
|
||||
reverse=True,
|
||||
)
|
||||
string = [x[0] for x in string]
|
||||
string = ", ".join(string)
|
||||
|
||||
return string
|
||||
return ", ".join(string)
|
||||
|
||||
@property
|
||||
def character_tags_string(self) -> str:
|
||||
@ -142,8 +140,7 @@ class Result:
|
||||
reverse=True,
|
||||
)
|
||||
string = [x[0] for x in string]
|
||||
string = ", ".join(string)
|
||||
return string
|
||||
return ", ".join(string)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a formatted string representation of the tags and their ratings."""
|
||||
@ -238,44 +235,41 @@ class Tagger:
|
||||
Image.BICUBIC,
|
||||
)
|
||||
|
||||
# Convert to numpy array
|
||||
image_array = np.asarray(padded_image, dtype=np.float32)
|
||||
|
||||
# Convert PIL-native RGB to BGR
|
||||
image_array = image_array[:, :, ::-1]
|
||||
|
||||
return np.expand_dims(image_array, axis=0)
|
||||
return np.asarray(padded_image, dtype=np.float32)
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image,
|
||||
image: Image.Image | list[Image.Image],
|
||||
general_threshold=0.35,
|
||||
character_threshold=0.9,
|
||||
):
|
||||
) -> Result | list[Result]:
|
||||
"""Tag the image and return the results.
|
||||
|
||||
Args:
|
||||
image (PIL.Image): Input image.
|
||||
image (PIL.Image | list[PIL.Image]): Input image or list of images.
|
||||
general_threshold (float): Threshold for general tags.
|
||||
character_threshold (float): Threshold for character tags.
|
||||
|
||||
Returns:
|
||||
Result: Object containing the tagging results.
|
||||
Result | list[Result]: Tagging results.
|
||||
"""
|
||||
with console.status("Tagging..."):
|
||||
image = self.prepare_image(image)
|
||||
image_array = np.asarray(image, dtype=np.float32)
|
||||
image_array = image_array[:, :, ::-1] # Convert PIL-native RGB to BGR
|
||||
|
||||
image_array = np.expand_dims(image_array, axis=0)
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
|
||||
preds = self.model.run([label_name], {input_name: image_array[0]})[0]
|
||||
result = Result(
|
||||
preds, self.sep_tags, general_threshold, character_threshold
|
||||
)
|
||||
return result
|
||||
started_at = time.time()
|
||||
images = [image] if isinstance(image, Image.Image) else image
|
||||
images = [self.prepare_image(img) for img in images]
|
||||
image_array = np.asarray(images, dtype=np.float32)
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
preds = self.model.run([label_name], {input_name: image_array})[0]
|
||||
results = [
|
||||
Result(pred, self.sep_tags, general_threshold, character_threshold)
|
||||
for pred in preds
|
||||
]
|
||||
duration = time.time() - started_at
|
||||
image_length = len(images)
|
||||
console.log(f"Tagging {image_length} image{
|
||||
's' if image_length > 1 else ''
|
||||
} took {duration:.2f} seconds.")
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
||||
__all__ = ["Tagger"]
|
||||
@ -283,5 +277,5 @@ __all__ = ["Tagger"]
|
||||
if __name__ == "__main__":
|
||||
tagger = Tagger()
|
||||
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
print(result)
|
||||
result = tagger.tag(image)
|
||||
console.log(result)
|
||||
|
Loading…
x
Reference in New Issue
Block a user