Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
447ad99aea | |||
52931f2a59 | |||
ec1f0b6e26 | |||
fe5ad89a29 | |||
c7fbdeb7b1 | |||
5bcbf991e8 |
951
poetry.lock
generated
951
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,12 @@
|
||||
[tool.poetry]
|
||||
name = "wdtagger"
|
||||
version = "0.10.0"
|
||||
version = "0.10.2"
|
||||
description = ""
|
||||
authors = ["Jianqi Pan <jannchie@gmail.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
pillow = "^10.3.0"
|
||||
pillow = ">=10.3.0,<12.0.0"
|
||||
pandas = "^2.2.2"
|
||||
huggingface-hub = "^0.23.3"
|
||||
rich = "^13.7.1"
|
||||
@ -31,9 +30,6 @@ line-length = 120
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.poetry.package.include]
|
||||
"stubs/**/*.pyi" = {}
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
@ -1,61 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from _typeshed import Incomplete
|
||||
from PIL import Image
|
||||
|
||||
Input = Union[np.ndarray, Image.Image, str, Path]
|
||||
|
||||
__all__ = ["Tagger"]
|
||||
|
||||
class Result:
|
||||
general_tag_data: Incomplete
|
||||
character_tag_data: Incomplete
|
||||
rating_data: Incomplete
|
||||
def __init__(self, pred, sep_tags, general_threshold: float = 0.35, character_threshold: float = 0.9) -> None: ...
|
||||
@property
|
||||
def general_tags(self): ...
|
||||
@property
|
||||
def character_tags(self): ...
|
||||
@property
|
||||
def rating(self): ...
|
||||
@property
|
||||
def general_tags_string(self) -> str: ...
|
||||
@property
|
||||
def character_tags_string(self) -> str: ...
|
||||
@property
|
||||
def all_tags(self) -> list[str]: ...
|
||||
@property
|
||||
def all_tags_string(self) -> str: ...
|
||||
|
||||
class Tagger:
|
||||
console: Incomplete
|
||||
logger: Incomplete
|
||||
model_target_size: Incomplete
|
||||
cache_dir: Incomplete
|
||||
hf_token: Incomplete
|
||||
def __init__(
|
||||
self,
|
||||
model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
|
||||
cache_dir: Incomplete | None = None,
|
||||
hf_token=...,
|
||||
loglevel=...,
|
||||
num_threads: Incomplete | None = None,
|
||||
providers: Incomplete | None = None,
|
||||
console: Incomplete | None = None,
|
||||
) -> None: ...
|
||||
sep_tags: Incomplete
|
||||
model: Incomplete
|
||||
def load_model(
|
||||
self,
|
||||
model_repo,
|
||||
cache_dir: Incomplete | None = None,
|
||||
hf_token: Incomplete | None = None,
|
||||
num_threads: int | None = None,
|
||||
providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
|
||||
): ...
|
||||
def pil_to_cv2_numpy(self, image): ...
|
||||
def tag(
|
||||
self, image: Input | list[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
|
||||
) -> Result | list[Result]: ...
|
@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Sequence, Union
|
||||
from typing import Any, List, Literal, Sequence, Union, overload
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@ -115,17 +115,17 @@ class Result:
|
||||
self.rating_data = rating_data
|
||||
|
||||
@property
|
||||
def general_tags(self):
|
||||
def general_tags(self) -> tuple[str]:
|
||||
"""Return general tags as a tuple."""
|
||||
return tuple(self.general_tag_data.keys())
|
||||
|
||||
@property
|
||||
def character_tags(self):
|
||||
def character_tags(self) -> tuple[str]:
|
||||
"""Return character tags as a tuple."""
|
||||
return tuple(self.character_tag_data.keys())
|
||||
|
||||
@property
|
||||
def rating(self):
|
||||
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
|
||||
"""Return the highest rated tag."""
|
||||
return max(self.rating_data, key=self.rating_data.get)
|
||||
|
||||
@ -307,6 +307,14 @@ class Tagger:
|
||||
array = array[:, :, [2, 1, 0]]
|
||||
return array
|
||||
|
||||
@overload
|
||||
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
|
||||
|
||||
@overload
|
||||
def tag(
|
||||
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
|
||||
) -> List[Result]: ...
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image: Union[Input, List[Input]],
|
||||
|
Reference in New Issue
Block a user