refactor(wdtagger): remove rich dependency and simplify initialization logic

This commit is contained in:
Jianqi Pan 2024-11-26 22:55:37 +09:00
parent 9fb556c41c
commit b1776ef564

View File

@ -14,7 +14,6 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from PIL.ImageFile import ImageFile
from rich.console import Console
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
@ -137,15 +136,7 @@ class Result:
character_tag: dict[str, float],
general_tag: dict[str, float],
) -> None:
"""Initialize the Result object to store tagging results.
Args:
preds (np.array): Predictions array from the model.
sep_tags (tuple): Tuple containing separated tags based on categories.
general_threshold (float): Threshold for general tags.
character_threshold (float): Threshold for character tags.
"""
"""Initialize the Result object with the tags and their ratings."""
self.general_tag_data = general_tag
self.character_tag_data = character_tag
self.rating_data = rating_data
@ -228,30 +219,9 @@ class Tagger:
self,
model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3",
hf_token: str = HF_TOKEN,
console: Console | None = None,
*,
slient: bool = False,
) -> None:
"""Initialize the Tagger object with the model repository and tokens.
Args:
model_repo (str): Repository name on HuggingFace.
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
self.slient = slient
if not slient:
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
self.logger = logging.getLogger("wdtagger")
"""Initialize the Tagger object with the model repository and tokens."""
self.logger = logging.getLogger("wdtagger")
self.model_target_size = None
self.hf_token = hf_token
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -269,11 +239,10 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
if not self.slient:
with self.console.status(f"Loading model '{model_repo}'..."):
self._do_load_model(model_repo)
else:
self._do_load_model(model_repo)
start_time = time.time()
self.logger.info("Loading model from %s", model_repo)
self._do_load_model(model_repo)
self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time)
def _do_load_model(self, model_repo: str) -> None:
model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
@ -354,13 +323,12 @@ class Tagger:
duration = time.time() - started_at
image_length = len(images)
if not self.slient:
self.logger.info(
"Tagged %d image%s in %.2fs",
image_length,
"s" if image_length > 1 else "",
duration,
)
self.logger.info(
"Tagged %d image%s in %.2fs",
image_length,
"s" if image_length > 1 else "",
duration,
)
if input_is_list:
return results
return results[0] if len(results) == 1 else results