Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
c910823173 | |||
16686ce3c4 | |||
9eafed31dc | |||
a4003953e6 |
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "wdtagger"
|
||||
version = "0.6.0"
|
||||
version = "0.8.0"
|
||||
description = ""
|
||||
authors = ["Jianqi Pan <jannchie@gmail.com>"]
|
||||
readme = "README.md"
|
||||
|
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
@ -20,3 +21,39 @@ def test_tagger(tagger, image_file):
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path(tagger, image_file):
    """Tag an image passed as a file path and check character and rating."""
    thresholds = {"character_threshold": 0.85, "general_threshold": 0.35}
    tagged = tagger.tag(image_file, **thresholds)

    assert tagged.character_tags_string == "akamatsu kaede"
    assert tagged.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger, image_file):
    """Tag an image passed as a NumPy array and check character and rating."""
    # Decode the file with PIL first, then hand the tagger the raw array.
    pixels = np.array(Image.open(image_file))
    tagged = tagger.tag(
        pixels,
        character_threshold=0.85,
        general_threshold=0.35,
    )

    assert tagged.character_tags_string == "akamatsu kaede"
    assert tagged.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger, image_file):
    """Tag an image passed as a PIL image and check character and rating."""
    pil_image = Image.open(image_file)
    tagged = tagger.tag(
        pil_image,
        character_threshold=0.85,
        general_threshold=0.35,
    )

    assert tagged.character_tags_string == "akamatsu kaede"
    assert tagged.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
def test_tagger_np_single(tagger, image_file):
    """Tag a one-element list of paths and check that a one-element list comes back."""
    # The parametrized value is itself a list, so the tagger receives list input.
    tagged_batch = tagger.tag(
        image_file,
        character_threshold=0.85,
        general_threshold=0.35,
    )
    assert len(tagged_batch) == 1

    only_result = tagged_batch[0]
    assert only_result.character_tags_string == "akamatsu kaede"
    assert only_result.rating == "general"
|
||||
|
@ -2,7 +2,8 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Sequence, Union
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@ -20,6 +21,19 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
||||
MODEL_FILENAME = "model.onnx"
|
||||
LABEL_FILENAME = "selected_tags.csv"
|
||||
|
||||
# A single image in any accepted form: a NumPy array, a PIL image,
# or a filesystem path given as ``str`` or ``Path``.
Input = Union[np.ndarray, Image.Image, str, Path]
|
||||
|
||||
|
||||
def to_pil(img: Input) -> Image.Image:
    """Convert a supported image input into a PIL image.

    Args:
        img: A filesystem path (``str`` or ``Path``), a NumPy array, or an
            existing PIL image.

    Returns:
        The image opened from the path, the image built from the array, or
        the PIL image itself, returned unchanged.

    Raises:
        ValueError: If ``img`` is none of the supported types.
    """
    # Path-like input: let PIL open the file.
    if isinstance(img, (str, Path)):
        return Image.open(img)

    # Array input: wrap the pixel data in a PIL image.
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)

    # Already a PIL image: nothing to convert.
    if isinstance(img, Image.Image):
        return img

    raise ValueError("Invalid input type.")
|
||||
|
||||
|
||||
def load_labels(dataframe) -> list[str]:
|
||||
"""Load labels from a dataframe and process tag names.
|
||||
@ -148,6 +162,15 @@ class Result:
|
||||
string = [x[0] for x in string]
|
||||
return ", ".join(string)
|
||||
|
||||
@property
def all_tags(self) -> list[str]:
    """Return the rating followed by the general tags and then the character tags."""
    return [self.rating, *self.general_tags, *self.character_tags]
|
||||
|
||||
@property
def all_tags_string(self) -> str:
    """Return every tag from ``all_tags`` joined into one comma-separated string."""
    separator = ", "
    return separator.join(self.all_tags)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a formatted string representation of the tags and their ratings."""
|
||||
|
||||
@ -241,7 +264,7 @@ class Tagger:
|
||||
self.model_target_size = height
|
||||
self.model = model
|
||||
|
||||
def prepare_image(self, image):
|
||||
def pil_to_cv2_numpy(self, image):
|
||||
"""Prepare the image for model input.
|
||||
|
||||
Args:
|
||||
@ -278,14 +301,14 @@ class Tagger:
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image: Image.Image | list[Image.Image],
|
||||
image: Union[Input, List[Input]],
|
||||
general_threshold=0.35,
|
||||
character_threshold=0.9,
|
||||
) -> Result | list[Result]:
|
||||
"""Tag the image and return the results.
|
||||
|
||||
Args:
|
||||
image (PIL.Image | list[PIL.Image]): Input image or list of images.
|
||||
image (Union[Input, List[Input]]): Input image or list of images to tag.
|
||||
general_threshold (float): Threshold for general tags.
|
||||
character_threshold (float): Threshold for character tags.
|
||||
|
||||
@ -293,8 +316,10 @@ class Tagger:
|
||||
Result | list[Result]: Tagging results.
|
||||
"""
|
||||
started_at = time.time()
|
||||
images = [image] if isinstance(image, Image.Image) else image
|
||||
images = [self.prepare_image(img) for img in images]
|
||||
input_is_list = isinstance(image, list)
|
||||
images = image if isinstance(image, list) else [image]
|
||||
images = [to_pil(img) for img in images]
|
||||
images = [self.pil_to_cv2_numpy(img) for img in images]
|
||||
image_array = np.asarray(images, dtype=np.float32)
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
@ -308,6 +333,8 @@ class Tagger:
|
||||
self.logger.info(
|
||||
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
||||
)
|
||||
if input_is_list:
|
||||
return results
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user