From a4003953e65279f95945c03fbb08e5299b1133d7 Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Sun, 23 Jun 2024 21:27:52 +0900 Subject: [PATCH] feat(input): enable input path or numpy array --- tests/test_tagger.py | 37 +++++++++++++++++++++++++++++++++++++ wdtagger/__init__.py | 30 ++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/tests/test_tagger.py b/tests/test_tagger.py index 0702662..09a0cca 100644 --- a/tests/test_tagger.py +++ b/tests/test_tagger.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from PIL import Image @@ -20,3 +21,39 @@ def test_tagger(tagger, image_file): assert result.character_tags_string == "akamatsu kaede" assert result.rating == "general" + + +@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) +def test_tagger_path(tagger, image_file): + result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35) + + assert result.character_tags_string == "akamatsu kaede" + assert result.rating == "general" + + +@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) +def test_tagger_np(tagger, image_file): + image = Image.open(image_file) + image_np = np.array(image) + result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35) + + assert result.character_tags_string == "akamatsu kaede" + assert result.rating == "general" + + +@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) +def test_tagger_pil(tagger, image_file): + image = Image.open(image_file) + result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35) + + assert result.character_tags_string == "akamatsu kaede" + assert result.rating == "general" + + +@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]]) +def test_tagger_np_single(tagger, image_file): + results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35) + assert len(results) == 1 + result = results[0] + assert result.character_tags_string == "akamatsu kaede" + assert result.rating == "general" diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 36967c4..798c1ce 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -2,7 +2,8 @@ import logging import os import time from collections import OrderedDict -from typing import Any, Sequence +from pathlib import Path +from typing import Any, List, Sequence, Union import huggingface_hub import numpy as np @@ -20,6 +21,19 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "") MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" +Input = Union[np.ndarray, Image.Image, str, Path] + + +def to_pil(img: Input) -> Image.Image: + if isinstance(img, (str, Path)): + return Image.open(img) + elif isinstance(img, np.ndarray): + return Image.fromarray(img) + elif isinstance(img, Image.Image): + return img + else: + raise ValueError("Invalid input type.") + def load_labels(dataframe) -> list[str]: """Load labels from a dataframe and process tag names. @@ -241,7 +255,7 @@ class Tagger: self.model_target_size = height self.model = model - def prepare_image(self, image): + def pil_to_cv2_numpy(self, image): """Prepare the image for model input. Args: @@ -278,14 +292,14 @@ class Tagger: def tag( self, - image: Image.Image | list[Image.Image], + image: Union[Input, List[Input]], general_threshold=0.35, character_threshold=0.9, ) -> Result | list[Result]: """Tag the image and return the results. Args: - image (PIL.Image | list[PIL.Image]): Input image or list of images. + image (Union[Input, List[Input]]): Input image or list of images to tag. general_threshold (float): Threshold for general tags. character_threshold (float): Threshold for character tags. @@ -293,8 +307,10 @@ class Tagger: Result | list[Result]: Tagging results. """ started_at = time.time() - images = [image] if isinstance(image, Image.Image) else image - images = [self.prepare_image(img) for img in images] + input_is_list = isinstance(image, list) + images = image if isinstance(image, list) else [image] + images = [to_pil(img) for img in images] + images = [self.pil_to_cv2_numpy(img) for img in images] image_array = np.asarray(images, dtype=np.float32) input_name = self.model.get_inputs()[0].name label_name = self.model.get_outputs()[0].name @@ -308,6 +324,8 @@ class Tagger: self.logger.info( f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds." ) + if input_is_list: + return results return results[0] if len(results) == 1 else results