|
|
|
@ -9,14 +9,9 @@ import huggingface_hub
|
|
|
|
|
import numpy as np
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import rich
|
|
|
|
|
import rich.live
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from rich.logging import RichHandler
|
|
|
|
|
|
|
|
|
|
# Access console for rich text and logging
|
|
|
|
|
console = rich.get_console()
|
|
|
|
|
|
|
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
|
|
|
|
MODEL_FILENAME = "model.onnx"
|
|
|
|
|
LABEL_FILENAME = "selected_tags.csv"
|
|
|
|
@ -192,6 +187,7 @@ class Tagger:
|
|
|
|
|
loglevel=logging.INFO,
|
|
|
|
|
num_threads=None,
|
|
|
|
|
providers=None,
|
|
|
|
|
console=None,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Tagger object with the model repository and tokens.
|
|
|
|
|
|
|
|
|
@ -201,7 +197,17 @@ class Tagger:
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
|
|
|
|
|
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
|
|
|
|
|
console (rich.console.Console, optional): Rich console object. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not console:
|
|
|
|
|
from rich import get_console
|
|
|
|
|
|
|
|
|
|
self.console = get_console()
|
|
|
|
|
else:
|
|
|
|
|
self.console = console
|
|
|
|
|
|
|
|
|
|
if providers is None:
|
|
|
|
|
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
|
|
|
self.logger = logging.getLogger("wdtagger")
|
|
|
|
@ -234,7 +240,7 @@ class Tagger:
|
|
|
|
|
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
with console.status("Loading model..."):
|
|
|
|
|
with self.console.status("Loading model..."):
|
|
|
|
|
csv_path = huggingface_hub.hf_hub_download(
|
|
|
|
|
model_repo,
|
|
|
|
|
LABEL_FILENAME,
|
|
|
|
|