|
|
|
@ -2,6 +2,7 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from typing import Any, Sequence
|
|
|
|
|
|
|
|
|
|
import huggingface_hub
|
|
|
|
|
import numpy as np
|
|
|
|
@ -15,16 +16,9 @@ from rich.logging import RichHandler
|
|
|
|
|
# Access console for rich text and logging
|
|
|
|
|
console = rich.get_console()
|
|
|
|
|
|
|
|
|
|
# Environment variables and file paths
|
|
|
|
|
HF_TOKEN = os.environ.get(
|
|
|
|
|
"HF_TOKEN", ""
|
|
|
|
|
) # Token for authentication with HuggingFace API
|
|
|
|
|
MODEL_FILENAME = "model.onnx" # ONNX model filename
|
|
|
|
|
LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename
|
|
|
|
|
|
|
|
|
|
available_providers = rt.get_available_providers()
|
|
|
|
|
supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
|
|
|
|
|
providers = list(set(available_providers) & set(supported_providers))
|
|
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
|
|
|
|
MODEL_FILENAME = "model.onnx"
|
|
|
|
|
LABEL_FILENAME = "selected_tags.csv"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_labels(dataframe) -> list[str]:
|
|
|
|
@ -174,6 +168,7 @@ class Tagger:
|
|
|
|
|
hf_token=HF_TOKEN,
|
|
|
|
|
loglevel=logging.INFO,
|
|
|
|
|
num_threads=None,
|
|
|
|
|
providers=None,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the Tagger object with the model repository and tokens.
|
|
|
|
|
|
|
|
|
@ -184,16 +179,29 @@ class Tagger:
|
|
|
|
|
loglevel (int, optional): Logging level. Defaults to logging.INFO.
|
|
|
|
|
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
if providers is None:
|
|
|
|
|
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
|
|
|
self.logger = logging.getLogger("wdtagger")
|
|
|
|
|
self.logger.setLevel(loglevel)
|
|
|
|
|
self.logger.addHandler(RichHandler())
|
|
|
|
|
self.model_target_size = None
|
|
|
|
|
self.cache_dir = cache_dir
|
|
|
|
|
self.hf_token = hf_token
|
|
|
|
|
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
|
|
|
|
|
self.load_model(
|
|
|
|
|
model_repo,
|
|
|
|
|
cache_dir,
|
|
|
|
|
hf_token,
|
|
|
|
|
num_threads=num_threads,
|
|
|
|
|
providers=providers,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def load_model(
|
|
|
|
|
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
|
|
|
|
|
self,
|
|
|
|
|
model_repo,
|
|
|
|
|
cache_dir=None,
|
|
|
|
|
hf_token=None,
|
|
|
|
|
num_threads: int = None,
|
|
|
|
|
providers: Sequence[str | tuple[str, dict[Any, Any]]] = None,
|
|
|
|
|
):
|
|
|
|
|
"""Load the model and tags from the specified repository.
|
|
|
|
|
|
|
|
|
|