From f9ec9de1576642da8aa0bb99a4bb16c3587daaae Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Tue, 11 Jun 2024 03:18:23 +0900 Subject: [PATCH] :sparkles: feat(logger): use logger --- .vscode/settings.json | 9 +++++++++ wdtagger/__init__.py | 34 ++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 8 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b473ead --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,9 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnType": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + }, + }, +} \ No newline at end of file diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 72f8868..3494ba1 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -1,3 +1,4 @@ +import logging import os import time from collections import OrderedDict @@ -9,6 +10,7 @@ import pandas as pd import rich import rich.live from PIL import Image +from rich.logging import RichHandler # Access console for rich text and logging console = rich.get_console() @@ -67,7 +69,13 @@ def load_labels(dataframe) -> list[str]: class Result: - def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9): + def __init__( + self, + pred, + sep_tags, + general_threshold=0.35, + character_threshold=0.9, + ): """Initialize the Result object to store tagging results. Args: @@ -160,6 +168,7 @@ class Tagger: model_repo="SmilingWolf/wd-swinv2-tagger-v3", cache_dir=None, hf_token=HF_TOKEN, + loglevel=logging.INFO, ): """Initialize the Tagger object with the model repository and tokens. @@ -168,12 +177,17 @@ class Tagger: cache_dir (str, optional): Directory to cache the model. Defaults to None. hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN. """ + self.logger = logging.getLogger("wdtagger") + self.logger.setLevel(loglevel) + self.logger.addHandler(RichHandler()) self.model_target_size = None self.cache_dir = cache_dir self.hf_token = hf_token self.load_model(model_repo, cache_dir, hf_token) - def load_model(self, model_repo, cache_dir=None, hf_token=None): + def load_model( + self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None + ): """Load the model and tags from the specified repository. Args: @@ -188,6 +202,7 @@ class Tagger: cache_dir=cache_dir, use_auth_token=hf_token, ) + model_path = huggingface_hub.hf_hub_download( model_repo, MODEL_FILENAME, @@ -197,8 +212,11 @@ class Tagger: tags_df = pd.read_csv(csv_path) self.sep_tags = load_labels(tags_df) - - model = rt.InferenceSession(model_path) + options = rt.SessionOptions() + if num_threads: + options.intra_op_num_threads = num_threads + options.inter_op_num_threads = num_threads + model = rt.InferenceSession(model_path, options) _, height, _, _ = model.get_inputs()[0].shape self.model_target_size = height self.model = model @@ -266,9 +284,9 @@ class Tagger: ] duration = time.time() - started_at image_length = len(images) - console.log(f"Tagging {image_length} image{ - 's' if image_length > 1 else '' - } took {duration:.2f} seconds.") + self.logger.info( + f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds." + ) return results[0] if len(results) == 1 else results @@ -278,4 +296,4 @@ if __name__ == "__main__": tagger = Tagger() image = Image.open("./tests/images/赤松楓.9d64b955.jpeg") result = tagger.tag(image) - console.log(result) + tagger.logger.info(result)