refactor(wdtagger): remove rich dependency and simplify initialization logic
commit b1776ef564
parent 9fb556c41c
@@ -14,7 +14,6 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 from PIL import Image
 from PIL.ImageFile import ImageFile
-from rich.console import Console
 from timm.data import create_transform, resolve_data_config
 from torch import Tensor, nn
 from torch.nn import functional as F
@@ -137,15 +136,7 @@ class Result:
         character_tag: dict[str, float],
         general_tag: dict[str, float],
     ) -> None:
-        """Initialize the Result object to store tagging results.
-
-        Args:
-            preds (np.array): Predictions array from the model.
-            sep_tags (tuple): Tuple containing separated tags based on categories.
-            general_threshold (float): Threshold for general tags.
-            character_threshold (float): Threshold for character tags.
-        """
-
+        """Initialize the Result object with the tags and their ratings."""
         self.general_tag_data = general_tag
         self.character_tag_data = character_tag
         self.rating_data = rating_data
@@ -228,30 +219,9 @@ class Tagger:
         self,
         model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
         hf_token: str = HF_TOKEN,
-        console: Console | None = None,
-        *,
-        slient: bool = False,
     ) -> None:
-        """Initialize the Tagger object with the model repository and tokens.
-
-        Args:
-            model_repo (str): Repository name on HuggingFace.
-            cache_dir (str, optional): Directory to cache the model. Defaults to None.
-            hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
-            loglevel (int, optional): Logging level. Defaults to logging.INFO.
-            num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
-            providers (list, optional): List of providers for ONNX runtime. Defaults to None.
-            console (rich.console.Console, optional): Rich console object. Defaults to None.
-        """
-        self.slient = slient
-        if not slient:
-            if not console:
-                from rich import get_console
-
-                self.console = get_console()
-            else:
-                self.console = console
-        self.logger = logging.getLogger("wdtagger")
+        """Initialize the Tagger object with the model repository and tokens."""
+        self.logger = logging.getLogger("wdtagger")
         self.model_target_size = None
         self.hf_token = hf_token
         self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
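Note on the removed `console` / `slient` parameters: the constructor no longer takes any output-related arguments, so the only switch left for verbosity is the `wdtagger` logger created above. A minimal caller-side sketch using only the standard library (the level choices are illustrative, not part of this commit):

    import logging

    # Show wdtagger's info-level messages (model loading, tagging timings)...
    logging.basicConfig(level=logging.INFO)

    # ...or silence them, which covers the use case of the removed `slient=True` flag.
    logging.getLogger("wdtagger").setLevel(logging.WARNING)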
@@ -269,11 +239,10 @@ class Tagger:
             hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
             num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
         """
-        if not self.slient:
-            with self.console.status(f"Loading model '{model_repo}'..."):
-                self._do_load_model(model_repo)
-        else:
-            self._do_load_model(model_repo)
+        start_time = time.time()
+        self.logger.info("Loading model from %s", model_repo)
+        self._do_load_model(model_repo)
+        self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time)
 
     def _do_load_model(self, model_repo: str) -> None:
         model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
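The rich status spinner around model loading is replaced by two plain log records with a wall-clock duration. Callers who still want rich-formatted output can attach rich's logging handler on their side; a sketch, assuming rich is installed in the application (it is no longer a wdtagger dependency):

    import logging

    from rich.logging import RichHandler

    logger = logging.getLogger("wdtagger")
    logger.setLevel(logging.INFO)
    logger.addHandler(RichHandler())  # pretty output supplied by the caller, not by wdtagger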
@@ -354,13 +323,12 @@ class Tagger:
 
         duration = time.time() - started_at
         image_length = len(images)
-        if not self.slient:
-            self.logger.info(
-                "Tagged %d image%s in %.2fs",
-                image_length,
-                "s" if image_length > 1 else "",
-                duration,
-            )
+        self.logger.info(
+            "Tagged %d image%s in %.2fs",
+            image_length,
+            "s" if image_length > 1 else "",
+            duration,
+        )
         if input_is_list:
             return results
         return results[0] if len(results) == 1 else results
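Putting the refactor together, a usage sketch of the simplified API: the attribute names come from Result.__init__ above, while the `tag` method name and the single-path call are assumptions not shown in this diff:

    import logging

    from wdtagger import Tagger

    logging.basicConfig(level=logging.INFO)  # progress now arrives as ordinary log records

    tagger = Tagger()  # defaults to "SmilingWolf/wd-swinv2-tagger-v3"; no console/slient args
    result = tagger.tag("photo.jpg")  # assumed method name; single input -> single Result
    print(result.general_tag_data)    # dict[str, float] of general tags
    print(result.character_tag_data)  # dict[str, float] of character tags
    print(result.rating_data)         # rating scores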