# wdtagger/tests/test_tagger.py
#
# 66 lines
# 2.1 KiB
# Python

import numpy as np
import pytest
from PIL import Image
from wdtagger import Tagger
@pytest.fixture
def tagger() -> Tagger:
    """Build a Tagger with no constructor arguments for the requesting test.

    Returns:
        Tagger: The newly constructed Tagger instance.
    """
    instance = Tagger()
    return instance
@pytest.fixture
def image_file() -> str:
    """Return the relative path of the sample JPEG used by these tests."""
    sample_path = "./tests/images/赤松楓.9d64b955.jpeg"
    return sample_path
def test_tagger(tagger: Tagger, image_file: str) -> None:
    """Tag the sample image opened via PIL and check character tag and rating.

    Args:
        tagger: Tagger instance supplied by the ``tagger`` fixture.
        image_file: Sample image path supplied by the ``image_file`` fixture.
    """
    # Image.open holds the file handle until the image is closed; tagging
    # inside the with-block releases it deterministically afterwards while the
    # pixel data is still readable during the call.
    with Image.open(image_file) as image:
        result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
    assert result.character_tags_string == "akamatsu_kaede"
    assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path_single(tagger: Tagger, image_file: str) -> None:
    """Tag a single image passed as a path string and check the outcome.

    Args:
        tagger: Tagger instance supplied by the ``tagger`` fixture.
        image_file: Image path provided by the parametrization above.
    """
    outcome = tagger.tag(
        image_file,
        general_threshold=0.35,
        character_threshold=0.85,
    )
    assert outcome.character_tags_string == "akamatsu_kaede"
    assert outcome.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger: Tagger, image_file: str) -> None:
    """Tag the sample image passed as a NumPy array and check the outcome.

    Args:
        tagger: Tagger instance supplied by the ``tagger`` fixture.
        image_file: Image path provided by the parametrization above.
    """
    # Convert to an array while the file is still open, then let the
    # with-block close the handle; the array no longer needs the file.
    with Image.open(image_file) as image:
        image_np = np.array(image)
    result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
    assert result.character_tags_string == "akamatsu_kaede"
    assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger: Tagger, image_file: str) -> None:
    """Tag the sample image passed as a PIL image and check the outcome.

    Args:
        tagger: Tagger instance supplied by the ``tagger`` fixture.
        image_file: Image path provided by the parametrization above.
    """
    # Keep the image open only for the duration of the tag call so the file
    # handle opened by Image.open is always released.
    with Image.open(image_file) as image:
        result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
    assert result.character_tags_string == "akamatsu_kaede"
    assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None:
    """Tag a list of two image paths and check every returned result.

    Args:
        tagger: Tagger instance supplied by the ``tagger`` fixture.
        image_file: Image path provided by the parametrization above.
    """
    results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35)
    assert len(results) == 2
    # Both inputs are the same image, so every entry must match the expected
    # values; checking only the first would let a bad second entry slip by.
    for result in results:
        assert result.character_tags_string == "akamatsu_kaede"
        assert result.rating == "general"