3 Commits

Author SHA1 Message Date
fe5ad89a29 version: v0.10.1 2024-08-24 19:12:20 +09:00
c7fbdeb7b1 chore(type): better result type 2024-08-24 19:12:07 +09:00
5bcbf991e8 🔧 chore(mypy): remove stub 2024-07-29 23:54:24 +09:00
3 changed files with 13 additions and 69 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.10.0"
version = "0.10.1"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"
@ -31,9 +31,6 @@ line-length = 120
[tool.mypy]
ignore_missing_imports = true
[tool.poetry.package.include]
"stubs/**/*.pyi" = {}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,61 +0,0 @@
from pathlib import Path
from typing import Any, Sequence, Union
import numpy as np
from _typeshed import Incomplete
from PIL import Image
Input = Union[np.ndarray, Image.Image, str, Path]
__all__ = ["Tagger"]
class Result:
general_tag_data: Incomplete
character_tag_data: Incomplete
rating_data: Incomplete
def __init__(self, pred, sep_tags, general_threshold: float = 0.35, character_threshold: float = 0.9) -> None: ...
@property
def general_tags(self): ...
@property
def character_tags(self): ...
@property
def rating(self): ...
@property
def general_tags_string(self) -> str: ...
@property
def character_tags_string(self) -> str: ...
@property
def all_tags(self) -> list[str]: ...
@property
def all_tags_string(self) -> str: ...
class Tagger:
console: Incomplete
logger: Incomplete
model_target_size: Incomplete
cache_dir: Incomplete
hf_token: Incomplete
def __init__(
self,
model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
cache_dir: Incomplete | None = None,
hf_token=...,
loglevel=...,
num_threads: Incomplete | None = None,
providers: Incomplete | None = None,
console: Incomplete | None = None,
) -> None: ...
sep_tags: Incomplete
model: Incomplete
def load_model(
self,
model_repo,
cache_dir: Incomplete | None = None,
hf_token: Incomplete | None = None,
num_threads: int | None = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
): ...
def pil_to_cv2_numpy(self, image): ...
def tag(
self, image: Input | list[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
) -> Result | list[Result]: ...

View File

@ -3,7 +3,7 @@ import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Sequence, Union
from typing import Any, List, Literal, Sequence, Union, overload
import huggingface_hub
import numpy as np
@ -115,17 +115,17 @@ class Result:
self.rating_data = rating_data
@property
def general_tags(self):
def general_tags(self) -> tuple[str]:
"""Return general tags as a tuple."""
return tuple(self.general_tag_data.keys())
@property
def character_tags(self):
def character_tags(self) -> tuple[str]:
"""Return character tags as a tuple."""
return tuple(self.character_tag_data.keys())
@property
def rating(self):
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
"""Return the highest rated tag."""
return max(self.rating_data, key=self.rating_data.get)
@ -307,6 +307,14 @@ class Tagger:
array = array[:, :, [2, 1, 0]]
return array
@overload
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
@overload
def tag(
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
) -> List[Result]: ...
def tag(
self,
image: Union[Input, List[Input]],