feat(logger): use logger

This commit is contained in:
Jianqi Pan
2024-06-11 03:18:23 +09:00
parent ea88303c8a
commit f9ec9de157
2 changed files with 35 additions and 8 deletions

9
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,9 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnType": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
}

View File

@@ -1,3 +1,4 @@
import logging
import os
import time
from collections import OrderedDict
@@ -9,6 +10,7 @@ import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
@@ -67,7 +69,13 @@ def load_labels(dataframe) -> list[str]:
class Result:
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
def __init__(
self,
pred,
sep_tags,
general_threshold=0.35,
character_threshold=0.9,
):
"""Initialize the Result object to store tagging results.
Args:
@@ -160,6 +168,7 @@ class Tagger:
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
cache_dir=None,
hf_token=HF_TOKEN,
loglevel=logging.INFO,
):
"""Initialize the Tagger object with the model repository and tokens.
@@ -168,12 +177,17 @@ class Tagger:
cache_dir (str, optional): Directory to cache the model. Defaults to None.
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
"""
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token)
def load_model(self, model_repo, cache_dir=None, hf_token=None):
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
):
"""Load the model and tags from the specified repository.
Args:
@@ -188,6 +202,7 @@ class Tagger:
cache_dir=cache_dir,
use_auth_token=hf_token,
)
model_path = huggingface_hub.hf_hub_download(
model_repo,
MODEL_FILENAME,
@@ -197,8 +212,11 @@ class Tagger:
tags_df = pd.read_csv(csv_path)
self.sep_tags = load_labels(tags_df)
model = rt.InferenceSession(model_path)
options = rt.SessionOptions()
if num_threads:
options.intra_op_num_threads = num_threads
options.inter_op_num_threads = num_threads
model = rt.InferenceSession(model_path, options)
_, height, _, _ = model.get_inputs()[0].shape
self.model_target_size = height
self.model = model
@@ -266,9 +284,9 @@ class Tagger:
]
duration = time.time() - started_at
image_length = len(images)
console.log(f"Tagging {image_length} image{
's' if image_length > 1 else ''
} took {duration:.2f} seconds.")
self.logger.info(
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
)
return results[0] if len(results) == 1 else results
@@ -278,4 +296,4 @@ if __name__ == "__main__":
tagger = Tagger()
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
result = tagger.tag(image)
console.log(result)
tagger.logger.info(result)