✨ feat(logger): use logger
This commit is contained in:
parent
ea88303c8a
commit
f9ec9de157
9
.vscode/settings.json
vendored
Normal file
9
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.formatOnType": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
},
|
||||
},
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
@ -9,6 +10,7 @@ import pandas as pd
|
||||
import rich
|
||||
import rich.live
|
||||
from PIL import Image
|
||||
from rich.logging import RichHandler
|
||||
|
||||
# Access console for rich text and logging
|
||||
console = rich.get_console()
|
||||
@ -67,7 +69,13 @@ def load_labels(dataframe) -> list[str]:
|
||||
|
||||
|
||||
class Result:
|
||||
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
|
||||
def __init__(
|
||||
self,
|
||||
pred,
|
||||
sep_tags,
|
||||
general_threshold=0.35,
|
||||
character_threshold=0.9,
|
||||
):
|
||||
"""Initialize the Result object to store tagging results.
|
||||
|
||||
Args:
|
||||
@ -160,6 +168,7 @@ class Tagger:
|
||||
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
|
||||
cache_dir=None,
|
||||
hf_token=HF_TOKEN,
|
||||
loglevel=logging.INFO,
|
||||
):
|
||||
"""Initialize the Tagger object with the model repository and tokens.
|
||||
|
||||
@ -168,12 +177,17 @@ class Tagger:
|
||||
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
||||
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
||||
"""
|
||||
self.logger = logging.getLogger("wdtagger")
|
||||
self.logger.setLevel(loglevel)
|
||||
self.logger.addHandler(RichHandler())
|
||||
self.model_target_size = None
|
||||
self.cache_dir = cache_dir
|
||||
self.hf_token = hf_token
|
||||
self.load_model(model_repo, cache_dir, hf_token)
|
||||
|
||||
def load_model(self, model_repo, cache_dir=None, hf_token=None):
|
||||
def load_model(
|
||||
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
||||
):
|
||||
"""Load the model and tags from the specified repository.
|
||||
|
||||
Args:
|
||||
@ -188,6 +202,7 @@ class Tagger:
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=hf_token,
|
||||
)
|
||||
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
model_repo,
|
||||
MODEL_FILENAME,
|
||||
@ -197,8 +212,11 @@ class Tagger:
|
||||
|
||||
tags_df = pd.read_csv(csv_path)
|
||||
self.sep_tags = load_labels(tags_df)
|
||||
|
||||
model = rt.InferenceSession(model_path)
|
||||
options = rt.SessionOptions()
|
||||
if num_threads:
|
||||
options.intra_op_num_threads = num_threads
|
||||
options.inter_op_num_threads = num_threads
|
||||
model = rt.InferenceSession(model_path, options)
|
||||
_, height, _, _ = model.get_inputs()[0].shape
|
||||
self.model_target_size = height
|
||||
self.model = model
|
||||
@ -266,9 +284,9 @@ class Tagger:
|
||||
]
|
||||
duration = time.time() - started_at
|
||||
image_length = len(images)
|
||||
console.log(f"Tagging {image_length} image{
|
||||
's' if image_length > 1 else ''
|
||||
} took {duration:.2f} seconds.")
|
||||
self.logger.info(
|
||||
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
||||
)
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
||||
@ -278,4 +296,4 @@ if __name__ == "__main__":
|
||||
tagger = Tagger()
|
||||
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
|
||||
result = tagger.tag(image)
|
||||
console.log(result)
|
||||
tagger.logger.info(result)
|
||||
|
Loading…
x
Reference in New Issue
Block a user