From e39c4e979928242f8958b473557aaeaf561ad904 Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Fri, 21 Feb 2025 22:35:59 +0900 Subject: [PATCH] refactor(tagger): remove unused script && update lint config && optimize imports && fix test cases --- a.py | 10 ---------- pyproject.toml | 6 ++++++ src/wdtagger/__init__.py | 17 +++++------------ tests/__init__.py | 0 tests/test_tagger.py | 36 +++++++++++++++++++++--------------- 5 files changed, 32 insertions(+), 37 deletions(-) delete mode 100644 a.py create mode 100644 tests/__init__.py diff --git a/a.py b/a.py deleted file mode 100644 index d63ab8e..0000000 --- a/a.py +++ /dev/null @@ -1,10 +0,0 @@ -from PIL import Image - -from wdtagger import Tagger - -if __name__ == "__main__": - tagger = Tagger() - images = [ - Image.open("./tests/images/赤松楓.9d64b955.jpeg"), - ] - results = tagger.tag(images) diff --git a/pyproject.toml b/pyproject.toml index 9e43372..39de255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,12 @@ ignore = [ "N812", ] +[tool.ruff.lint.extend-per-file-ignores] +"tests/**/*.py" = [ + # at least this three should be fine in tests: + "S101", # asserts allowed in tests + "PLR2004", # Magic value used in comparison +] [tool.pyright] [tool.uv] diff --git a/src/wdtagger/__init__.py b/src/wdtagger/__init__.py index 136572a..4a5999a 100644 --- a/src/wdtagger/__init__.py +++ b/src/wdtagger/__init__.py @@ -22,13 +22,6 @@ if TYPE_CHECKING: HF_TOKEN = os.environ.get("HF_TOKEN", "") -LABEL_FILENAME = "selected_tags.csv" - -MODEL_REPO_MAP = { - "vit": "SmilingWolf/wd-vit-tagger-v3", - "swinv2": "SmilingWolf/wd-swinv2-tagger-v3", - "convnext": "SmilingWolf/wd-convnext-tagger-v3", -} Input = np.ndarray | Image.Image | str | Path | ImageFile @@ -43,8 +36,9 @@ class LabelData: def load_labels() -> LabelData: - with importlib.resources.path("wdtagger.assets", "selected_tags.csv") as tags_path: - df: pd.DataFrame = pd.read_csv(tags_path, usecols=["name", "category"]) + file = importlib.resources.as_file(importlib.resources.files("wdtagger.assets").joinpath("selected_tags.csv")) + with file as f: + df: pd.DataFrame = pd.read_csv(f, usecols=["name", "category"]) rating_catagory_idx = 9 general_catagory_idx = 0 character_catagory_idx = 4 @@ -273,8 +267,7 @@ class Tagger: Result | list[Result]: Tagging results. """ started_at = time.time() - input_is_list = isinstance(image, list) - images = list(image) if isinstance(image, Sequence) else [image] + images = list(image) if isinstance(image, Sequence) and not isinstance(image, str) else [image] images = [to_pil(img) for img in images] images = [pil_ensure_rgb(img) for img in images] images = [pil_pad_square(img) for img in images] @@ -314,7 +307,7 @@ class Tagger: "s" if image_length > 1 else "", duration, ) - if input_is_list: + if isinstance(image, Sequence) and not isinstance(image, str): return results return results[0] if len(results) == 1 else results diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tagger.py b/tests/test_tagger.py index 09a0cca..75ef620 100644 --- a/tests/test_tagger.py +++ b/tests/test_tagger.py @@ -6,54 +6,60 @@ from wdtagger import Tagger @pytest.fixture -def tagger(): +def tagger() -> Tagger: + """ + Create and return a new instance of the Tagger class. + + Returns: + Tagger: An instance of the Tagger class. + """ return Tagger() @pytest.fixture -def image_file(): +def image_file() -> str: return "./tests/images/赤松楓.9d64b955.jpeg" -def test_tagger(tagger, image_file): +def test_tagger(tagger: Tagger, image_file: str) -> None: image = Image.open(image_file) result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35) - assert result.character_tags_string == "akamatsu kaede" + assert result.character_tags_string == "akamatsu_kaede" assert result.rating == "general" @pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) -def test_tagger_path(tagger, image_file): +def test_tagger_path_single(tagger: Tagger, image_file: str) -> None: result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35) - assert result.character_tags_string == "akamatsu kaede" + assert result.character_tags_string == "akamatsu_kaede" assert result.rating == "general" @pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) -def test_tagger_np(tagger, image_file): +def test_tagger_np(tagger: Tagger, image_file: str) -> None: image = Image.open(image_file) image_np = np.array(image) result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35) - assert result.character_tags_string == "akamatsu kaede" + assert result.character_tags_string == "akamatsu_kaede" assert result.rating == "general" @pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) -def test_tagger_pil(tagger, image_file): +def test_tagger_pil(tagger: Tagger, image_file: str) -> None: image = Image.open(image_file) result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35) - assert result.character_tags_string == "akamatsu kaede" + assert result.character_tags_string == "akamatsu_kaede" assert result.rating == "general" -@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]]) -def test_tagger_np_single(tagger, image_file): - results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35) - assert len(results) == 1 +@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"]) +def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None: + results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35) + assert len(results) == 2 result = results[0] - assert result.character_tags_string == "akamatsu kaede" + assert result.character_tags_string == "akamatsu_kaede" assert result.rating == "general"