refactor(tagger): remove unused script && update lint config && optimize imports && fix test cases
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@ -6,54 +6,60 @@ from wdtagger import Tagger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tagger():
|
||||
def tagger() -> Tagger:
|
||||
"""
|
||||
Create and return a new instance of the Tagger class.
|
||||
|
||||
Returns:
|
||||
Tagger: An instance of the Tagger class.
|
||||
"""
|
||||
return Tagger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_file():
|
||||
def image_file() -> str:
|
||||
return "./tests/images/赤松楓.9d64b955.jpeg"
|
||||
|
||||
|
||||
def test_tagger(tagger, image_file):
|
||||
def test_tagger(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path(tagger, image_file):
|
||||
def test_tagger_path_single(tagger: Tagger, image_file: str) -> None:
|
||||
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_np(tagger, image_file):
|
||||
def test_tagger_np(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
image_np = np.array(image)
|
||||
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_pil(tagger, image_file):
|
||||
def test_tagger_pil(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
|
||||
def test_tagger_np_single(tagger, image_file):
|
||||
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 1
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None:
|
||||
results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 2
|
||||
result = results[0]
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
Reference in New Issue
Block a user