Compare commits
8 Commits
Author | SHA1 | Date | |
---|---|---|---|
874132775e | |||
33a690a109 | |||
3e02bbbba5 | |||
e39c4e9799 | |||
314787404f | |||
243d1c802a | |||
b461680bc0 | |||
458a0d9410 |
@ -1,24 +1,19 @@
|
||||
[project]
|
||||
name = "wdtagger"
|
||||
version = "0.11.2"
|
||||
version = "0.13.1"
|
||||
description = "A simple and easy-to-use wrapper for the tagger model created by [SmilingWolf](https://github.com/SmilingWolf) which is specifically designed for tagging anime illustrations."
|
||||
authors = [{ name = "Jianqi Pan", email = "jannchie@gmail.com" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"huggingface-hub>=0.26.2",
|
||||
"numpy<2",
|
||||
"pandas>=2.2.3",
|
||||
"pillow>=11.0.0",
|
||||
]
|
||||
dependencies = ["huggingface-hub>=0.26.2", "pandas>=2.2.3", "pillow>=11.0.0"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pytest>=8.3.3", "pytest-benchmark>=5.1.0", "ruff>=0.8.0"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = ["torch>=2.5.1", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
gpu = ["torch>=2.5.1", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
cpu = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
gpu = ["torch>=2.0.0", "torchvision>=0.20.1", "timm>=1.0.11"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 140
|
||||
@ -36,21 +31,37 @@ ignore = [
|
||||
"N812",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # asserts allowed in tests
|
||||
"PLR2004", # Magic value used in comparison
|
||||
]
|
||||
[tool.pyright]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [[{ extra = "cpu" }, { extra = "gpu" }]]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
{ extra = "cuda11" },
|
||||
{ extra = "cuda12" },
|
||||
],
|
||||
]
|
||||
package = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "torch-cpu", extra = "cpu" },
|
||||
{ index = "torch-gpu", extra = "gpu" },
|
||||
{ index = "torch-cuda11", extra = "cuda11" },
|
||||
{ index = "torch-cuda12", extra = "cuda12" },
|
||||
]
|
||||
|
||||
torchvision = [
|
||||
{ index = "torch-cpu", extra = "cpu" },
|
||||
{ index = "torch-gpu", extra = "gpu" },
|
||||
{ index = "torch-cuda11", extra = "cuda11" },
|
||||
{ index = "torch-cuda12", extra = "cuda12" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
@ -63,6 +74,16 @@ name = "torch-gpu"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-cuda11"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "torch-cuda12"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
@ -22,13 +22,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
||||
LABEL_FILENAME = "selected_tags.csv"
|
||||
|
||||
MODEL_REPO_MAP = {
|
||||
"vit": "SmilingWolf/wd-vit-tagger-v3",
|
||||
"swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
|
||||
"convnext": "SmilingWolf/wd-convnext-tagger-v3",
|
||||
}
|
||||
|
||||
|
||||
Input = np.ndarray | Image.Image | str | Path | ImageFile
|
||||
@ -43,8 +36,9 @@ class LabelData:
|
||||
|
||||
|
||||
def load_labels() -> LabelData:
|
||||
with importlib.resources.path("wdtagger.assets", "selected_tags.csv") as tags_path:
|
||||
df: pd.DataFrame = pd.read_csv(tags_path, usecols=["name", "category"])
|
||||
file = importlib.resources.as_file(importlib.resources.files("wdtagger.assets").joinpath("selected_tags.csv"))
|
||||
with file as f:
|
||||
df: pd.DataFrame = pd.read_csv(f, usecols=["name", "category"])
|
||||
rating_catagory_idx = 9
|
||||
general_catagory_idx = 0
|
||||
character_catagory_idx = 4
|
||||
@ -273,8 +267,7 @@ class Tagger:
|
||||
Result | list[Result]: Tagging results.
|
||||
"""
|
||||
started_at = time.time()
|
||||
input_is_list = isinstance(image, list)
|
||||
images = list(image) if isinstance(image, Sequence) else [image]
|
||||
images = list(image) if isinstance(image, Sequence) and not isinstance(image, str) else [image]
|
||||
images = [to_pil(img) for img in images]
|
||||
images = [pil_ensure_rgb(img) for img in images]
|
||||
images = [pil_pad_square(img) for img in images]
|
||||
@ -314,7 +307,7 @@ class Tagger:
|
||||
"s" if image_length > 1 else "",
|
||||
duration,
|
||||
)
|
||||
if input_is_list:
|
||||
if isinstance(image, Sequence) and not isinstance(image, str):
|
||||
return results
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@ -6,54 +6,60 @@ from wdtagger import Tagger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tagger():
|
||||
def tagger() -> Tagger:
|
||||
"""
|
||||
Create and return a new instance of the Tagger class.
|
||||
|
||||
Returns:
|
||||
Tagger: An instance of the Tagger class.
|
||||
"""
|
||||
return Tagger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_file():
|
||||
def image_file() -> str:
|
||||
return "./tests/images/赤松楓.9d64b955.jpeg"
|
||||
|
||||
|
||||
def test_tagger(tagger, image_file):
|
||||
def test_tagger(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path(tagger, image_file):
|
||||
def test_tagger_path_single(tagger: Tagger, image_file: str) -> None:
|
||||
result = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_np(tagger, image_file):
|
||||
def test_tagger_np(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
image_np = np.array(image)
|
||||
result = tagger.tag(image_np, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_pil(tagger, image_file):
|
||||
def test_tagger_pil(tagger: Tagger, image_file: str) -> None:
|
||||
image = Image.open(image_file)
|
||||
result = tagger.tag(image, character_threshold=0.85, general_threshold=0.35)
|
||||
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_file", [["./tests/images/赤松楓.9d64b955.jpeg"]])
|
||||
def test_tagger_np_single(tagger, image_file):
|
||||
results = tagger.tag(image_file, character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 1
|
||||
@pytest.mark.parametrize("image_file", ["./tests/images/赤松楓.9d64b955.jpeg"])
|
||||
def test_tagger_path_multi(tagger: Tagger, image_file: str) -> None:
|
||||
results = tagger.tag([image_file, image_file], character_threshold=0.85, general_threshold=0.35)
|
||||
assert len(results) == 2
|
||||
result = results[0]
|
||||
assert result.character_tags_string == "akamatsu kaede"
|
||||
assert result.character_tags_string == "akamatsu_kaede"
|
||||
assert result.rating == "general"
|
||||
|
Reference in New Issue
Block a user