refactor(tagger): remove unused script && update lint config && optimize imports && fix test cases

This commit is contained in:
Jianqi Pan 2025-02-21 22:35:59 +09:00
parent 314787404f
commit e39c4e9799
5 changed files with 32 additions and 37 deletions

10
a.py
View File

@ -1,10 +0,0 @@
from PIL import Image
from wdtagger import Tagger
if __name__ == "__main__":
tagger = Tagger()
images = [
Image.open("./tests/images/赤松楓.9d64b955.jpeg"),
]
results = tagger.tag(images)

View File

@ -36,6 +36,12 @@ ignore = [
"N812",
]
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
# at least this three should be fine in tests:
"S101", # asserts allowed in tests
"PLR2004", # Magic value used in comparison
]
[tool.pyright]
[tool.uv]

View File

@ -22,13 +22,6 @@ if TYPE_CHECKING:
HF_TOKEN = os.environ.get("HF_TOKEN", "")
LABEL_FILENAME = "selected_tags.csv"
MODEL_REPO_MAP = {
"vit": "SmilingWolf/wd-vit-tagger-v3",
"swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
"convnext": "SmilingWolf/wd-convnext-tagger-v3",
}
Input = np.ndarray | Image.Image | str | Path | ImageFile
@ -43,8 +36,9 @@ class LabelData:
def load_labels() -> LabelData:
with importlib.resources.path("wdtagger.assets", "selected_tags.csv") as tags_path:
df: pd.DataFrame = pd.read_csv(tags_path, usecols=["name", "category"])
file = importlib.resources.as_file(importlib.resources.files("wdtagger.assets").joinpath("selected_tags.csv"))
with file as f:
df: pd.DataFrame = pd.read_csv(f, usecols=["name", "category"])
rating_catagory_idx = 9
general_catagory_idx = 0
character_catagory_idx = 4
@ -273,8 +267,7 @@ class Tagger:
Result | list[Result]: Tagging results.
"""
started_at = time.time()
input_is_list = isinstance(image, list)
images = list(image) if isinstance(image, Sequence) else [image]
images = list(image) if isinstance(image, Sequence) and not isinstance(image, str) else [image]
images = [to_pil(img) for img in images]
images = [pil_ensure_rgb(img) for img in images]
images = [pil_pad_square(img) for img in images]
@ -314,7 +307,7 @@ class Tagger:
"s" if image_length > 1 else "",
duration,
)
if input_is_list:
if isinstance(image, Sequence) and not isinstance(image, str):
return results
return results[0] if len(results) == 1 else results

0
tests/__init__.py Normal file
View File

View File

@ -6,54 +6,60 @@ from wdtagger import Tagger
@pytest.fixture
def tagger():
def tagger() -> Tagger:
"""
Create and return a new instance of the Tagger class.
Returns:
Tagger: An instance of the Tagger class.
"""
return Tagger()
@pytest.fixture
def image_file():
def image_file() -> str:
return "./tests/images/赤松楓.9d64b955.jpeg"
def test_tagger(tagger, image_file):
def test_tagger(tagger: Tagger, image_file: str) -> None:
image = Image.open(image_file)
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.character_tags_string == "akamatsu_kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path(tagger, image_file):
def test_tagger_path_single(tagger: Tagger, image_file: str) -> None:
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.character_tags_string == "akamatsu_kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_np(tagger, image_file):
def test_tagger_np(tagger: Tagger, image_file: str) -> None:
image = Image.open(image_file)
image_np = np.array(image)
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.character_tags_string == "akamatsu_kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_pil(tagger, image_file):
def test_tagger_pil(tagger: Tagger, image_file: str) -> None:
image = Image.open(image_file)
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
assert result.character_tags_string == "akamatsu kaede"
assert result.character_tags_string == "akamatsu_kaede"
assert result.rating == "general"
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
def test_tagger_np_single(tagger, image_file):
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
assert len(results) == 1
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None:
results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35)
assert len(results) == 2
result = results[0]
assert result.character_tags_string == "akamatsu kaede"
assert result.character_tags_string == "akamatsu_kaede"
assert result.rating == "general"