feat(input): enable input path or numpy array

This commit is contained in:
Jianqi Pan 2024-06-23 21:27:52 +09:00
parent 69f35eb373
commit a4003953e6
2 changed files with 61 additions and 6 deletions

View File

@ -1,3 +1,4 @@
import numpy as np
import pytest
from PIL import Image
@ -20,3 +21,39 @@ def test_tagger(tagger, image_file):
assert result.character_tags_string == "akamatsu kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path(tagger, image_file):
    """Tag an image that is passed to the tagger as a path string."""
    thresholds = {"character_threshold": 0.85, "general_threshold": 0.35}
    tagged = tagger.tag(image_file, **thresholds)
    assert tagged.character_tags_string == "akamatsu kaede"
    assert tagged.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger, image_file):
    """Tag an image that is passed to the tagger as a numpy array."""
    # Image.open() is lazy and keeps the file handle open; decode the pixels
    # into an array inside the context manager so the handle is released
    # deterministically instead of leaking until garbage collection.
    with Image.open(image_file) as image:
        image_np = np.array(image)
    result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
    assert result.character_tags_string == "akamatsu kaede"
    assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger, image_file):
    """Tag an image that is passed to the tagger as a PIL image object."""
    # Keep the image open only for the duration of the tagging call so the
    # underlying file handle is closed once the test is done with it.
    with Image.open(image_file) as image:
        result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
    assert result.character_tags_string == "akamatsu kaede"
    assert result.rating == "general"
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
def test_tagger_np_single(tagger, image_file):
    """Tag a one-element list and expect a one-element list of results.

    NOTE(review): despite the "np" in the name, the parametrized input here
    is a list containing a single path string, not a numpy array.
    """
    thresholds = {"character_threshold": 0.85, "general_threshold": 0.35}
    tagged_list = tagger.tag(image_file, **thresholds)
    assert len(tagged_list) == 1
    only = tagged_list[0]
    assert only.character_tags_string == "akamatsu kaede"
    assert only.rating == "general"

View File

@ -2,7 +2,8 @@ import logging
import os
import time
from collections import OrderedDict
from typing import Any, Sequence
from pathlib import Path
from typing import Any, List, Sequence, Union
import huggingface_hub
import numpy as np
@ -20,6 +21,19 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Any single-image value accepted by `to_pil`: a numpy array, a PIL image,
# or a filesystem path given as `str` or `pathlib.Path`.
Input = Union[np.ndarray, Image.Image, str, Path]
def to_pil(img: Input) -> Image.Image:
    """Convert a supported input value into a PIL image.

    Args:
        img: A file path (``str`` or ``Path``), a numpy array, or a PIL image.

    Returns:
        Image.Image: The image opened from the path, built from the array,
        or ``img`` itself when it is already a PIL image.

    Raises:
        ValueError: If ``img`` is none of the supported types.
    """
    # Paths are opened from disk.
    if isinstance(img, (str, Path)):
        return Image.open(img)
    # Arrays are wrapped as an image.
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    # PIL images pass through untouched.
    if isinstance(img, Image.Image):
        return img
    raise ValueError("Invalid input type.")
def load_labels(dataframe) -> list[str]:
"""Load labels from a dataframe and process tag names.
@ -241,7 +255,7 @@ class Tagger:
self.model_target_size = height
self.model = model
def prepare_image(self, image):
def pil_to_cv2_numpy(self, image):
"""Prepare the image for model input.
Args:
@ -278,14 +292,14 @@ class Tagger:
def tag(
self,
image: Image.Image | list[Image.Image],
image: Union[Input, List[Input]],
general_threshold=0.35,
character_threshold=0.9,
) -> Result | list[Result]:
"""Tag the image and return the results.
Args:
image (PIL.Image | list[PIL.Image]): Input image or list of images.
image (Union[Input, List[Input]]): Input image or list of images to tag.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
@ -293,8 +307,10 @@ class Tagger:
Result | list[Result]: Tagging results.
"""
started_at = time.time()
images = [image] if isinstance(image, Image.Image) else image
images = [self.prepare_image(img) for img in images]
input_is_list = isinstance(image, list)
images = image if isinstance(image, list) else [image]
images = [to_pil(img) for img in images]
images = [self.pil_to_cv2_numpy(img) for img in images]
image_array = np.asarray(images, dtype=np.float32)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
@ -308,6 +324,8 @@ class Tagger:
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
if input_is_list:
return results
return results[0] if len(results) == 1 else results