|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
@ -9,6 +10,7 @@ import pandas as pd
|
|
|
|
|
import rich
|
|
|
|
|
import rich.live
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from rich.logging import RichHandler
|
|
|
|
|
|
|
|
|
|
# Access console for rich text and logging
|
|
|
|
|
console = rich.get_console()
|
|
|
|
@ -67,7 +69,13 @@ def load_labels(dataframe) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Result:
|
|
|
|
|
def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
pred,
|
|
|
|
|
sep_tags,
|
|
|
|
|
general_threshold=0.35,
|
|
|
|
|
character_threshold=0.9,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Result object to store tagging results.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -160,6 +168,8 @@ class Tagger:
|
|
|
|
|
model_repo="SmilingWolf/wd-swinv2-tagger-v3",
|
|
|
|
|
cache_dir=None,
|
|
|
|
|
hf_token=HF_TOKEN,
|
|
|
|
|
loglevel=logging.INFO,
|
|
|
|
|
num_threads=None,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Tagger object with the model repository and tokens.
|
|
|
|
|
|
|
|
|
@ -167,19 +177,27 @@ class Tagger:
|
|
|
|
|
model_repo (str): Repository name on HuggingFace.
|
|
|
|
|
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
|
|
|
|
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
self.logger = logging.getLogger("wdtagger")
|
|
|
|
|
self.logger.setLevel(loglevel)
|
|
|
|
|
self.logger.addHandler(RichHandler())
|
|
|
|
|
self.model_target_size = None
|
|
|
|
|
self.cache_dir = cache_dir
|
|
|
|
|
self.hf_token = hf_token
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token)
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
|
|
|
|
|
|
|
|
|
|
def load_model(self, model_repo, cache_dir=None, hf_token=None):
|
|
|
|
|
def load_model(
|
|
|
|
|
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
|
|
|
|
):
|
|
|
|
|
"""Load the model and tags from the specified repository.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model_repo (str): Repository name on HuggingFace.
|
|
|
|
|
cache_dir (str, optional): Directory to cache the model. Defaults to None.
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
with console.status("Loading model..."):
|
|
|
|
|
csv_path = huggingface_hub.hf_hub_download(
|
|
|
|
@ -188,6 +206,7 @@ class Tagger:
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
use_auth_token=hf_token,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
model_path = huggingface_hub.hf_hub_download(
|
|
|
|
|
model_repo,
|
|
|
|
|
MODEL_FILENAME,
|
|
|
|
@ -197,8 +216,11 @@ class Tagger:
|
|
|
|
|
|
|
|
|
|
tags_df = pd.read_csv(csv_path)
|
|
|
|
|
self.sep_tags = load_labels(tags_df)
|
|
|
|
|
|
|
|
|
|
model = rt.InferenceSession(model_path)
|
|
|
|
|
options = rt.SessionOptions()
|
|
|
|
|
if num_threads:
|
|
|
|
|
options.intra_op_num_threads = num_threads
|
|
|
|
|
options.inter_op_num_threads = num_threads
|
|
|
|
|
model = rt.InferenceSession(model_path, options)
|
|
|
|
|
_, height, _, _ = model.get_inputs()[0].shape
|
|
|
|
|
self.model_target_size = height
|
|
|
|
|
self.model = model
|
|
|
|
@ -234,8 +256,9 @@ class Tagger:
|
|
|
|
|
(target_size, target_size),
|
|
|
|
|
Image.BICUBIC,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
array = np.asarray(padded_image, dtype=np.float32)
|
|
|
|
|
array = array[:, :, [2, 1, 0]]
|
|
|
|
|
return array
|
|
|
|
|
|
|
|
|
|
def tag(
|
|
|
|
|
self,
|
|
|
|
@ -266,9 +289,9 @@ class Tagger:
|
|
|
|
|
]
|
|
|
|
|
duration = time.time() - started_at
|
|
|
|
|
image_length = len(images)
|
|
|
|
|
console.log(f"Tagging {image_length} image{
|
|
|
|
|
's' if image_length > 1 else ''
|
|
|
|
|
} took {duration:.2f} seconds.")
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
|
|
|
|
|
)
|
|
|
|
|
return results[0] if len(results) == 1 else results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -278,4 +301,4 @@ if __name__ == "__main__":
|
|
|
|
|
tagger = Tagger()
|
|
|
|
|
image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
|
|
|
|
|
result = tagger.tag(image)
|
|
|
|
|
console.log(result)
|
|
|
|
|
tagger.logger.info(result)
|
|
|
|
|