diff --git a/src/wdtagger/__init__.py b/src/wdtagger/__init__.py index e6f9482..e99f8b9 100644 --- a/src/wdtagger/__init__.py +++ b/src/wdtagger/__init__.py @@ -14,7 +14,6 @@ from huggingface_hub import hf_hub_download from huggingface_hub.utils import HfHubHTTPError from PIL import Image from PIL.ImageFile import ImageFile -from rich.console import Console from timm.data import create_transform, resolve_data_config from torch import Tensor, nn from torch.nn import functional as F @@ -137,15 +136,7 @@ class Result: character_tag: dict[str, float], general_tag: dict[str, float], ) -> None: - """Initialize the Result object to store tagging results. - - Args: - preds (np.array): Predictions array from the model. - sep_tags (tuple): Tuple containing separated tags based on categories. - general_threshold (float): Threshold for general tags. - character_threshold (float): Threshold for character tags. - """ - + """Initialize the Result object with the tags and their ratings.""" self.general_tag_data = general_tag self.character_tag_data = character_tag self.rating_data = rating_data @@ -228,30 +219,9 @@ class Tagger: self, model_repo: str = "SmilingWolf/wd-swinv2-tagger-v3", hf_token: str = HF_TOKEN, - console: Console | None = None, - *, - slient: bool = False, ) -> None: - """Initialize the Tagger object with the model repository and tokens. - - Args: - model_repo (str): Repository name on HuggingFace. - cache_dir (str, optional): Directory to cache the model. Defaults to None. - hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN. - loglevel (int, optional): Logging level. Defaults to logging.INFO. - num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None. - providers (list, optional): List of providers for ONNX runtime. Defaults to None. - console (rich.console.Console, optional): Rich console object. Defaults to None. - """ - self.slient = slient - if not slient: - if not console: - from rich import get_console - - self.console = get_console() - else: - self.console = console - self.logger = logging.getLogger("wdtagger") + """Initialize the Tagger object with the model repository and tokens.""" + self.logger = logging.getLogger("wdtagger") self.model_target_size = None self.hf_token = hf_token self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -269,11 +239,10 @@ class Tagger: hf_token (str, optional): HuggingFace token for authentication. Defaults to None. num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None. """ - if not self.slient: - with self.console.status(f"Loading model '{model_repo}'..."): - self._do_load_model(model_repo) - else: - self._do_load_model(model_repo) + start_time = time.time() + self.logger.info("Loading model from %s", model_repo) + self._do_load_model(model_repo) + self.logger.info("Model loaded successfully in %.2fs", time.time() - start_time) def _do_load_model(self, model_repo: str) -> None: model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval() @@ -354,13 +323,12 @@ class Tagger: duration = time.time() - started_at image_length = len(images) - if not self.slient: - self.logger.info( - "Tagged %d image%s in %.2fs", - image_length, - "s" if image_length > 1 else "", - duration, - ) + self.logger.info( + "Tagged %d image%s in %.2fs", + image_length, + "s" if image_length > 1 else "", + duration, + ) if input_is_list: return results return results[0] if len(results) == 1 else results