diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 922041b..36967c4 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -2,6 +2,7 @@ import logging import os import time from collections import OrderedDict +from typing import Any, Sequence import huggingface_hub import numpy as np @@ -15,16 +16,9 @@ from rich.logging import RichHandler # Access console for rich text and logging console = rich.get_console() -# Environment variables and file paths -HF_TOKEN = os.environ.get( - "HF_TOKEN", "" -) # Token for authentication with HuggingFace API -MODEL_FILENAME = "model.onnx" # ONNX model filename -LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename - -available_providers = rt.get_available_providers() -supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"] -providers = list(set(available_providers) & set(supported_providers)) +HF_TOKEN = os.environ.get("HF_TOKEN", "") +MODEL_FILENAME = "model.onnx" +LABEL_FILENAME = "selected_tags.csv" def load_labels(dataframe) -> list[str]: @@ -174,6 +168,7 @@ class Tagger: hf_token=HF_TOKEN, loglevel=logging.INFO, num_threads=None, + providers=None, ): """Initialize the Tagger object with the model repository and tokens. @@ -184,16 +179,29 @@ class Tagger: loglevel (int, optional): Logging level. Defaults to logging.INFO. num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None. """ + if providers is None: + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] self.logger = logging.getLogger("wdtagger") self.logger.setLevel(loglevel) self.logger.addHandler(RichHandler()) self.model_target_size = None self.cache_dir = cache_dir self.hf_token = hf_token - self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads) + self.load_model( + model_repo, + cache_dir, + hf_token, + num_threads=num_threads, + providers=providers, + ) def load_model( - self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None + self, + model_repo, + cache_dir=None, + hf_token=None, + num_threads: int = None, + providers: Sequence[str | tuple[str, dict[Any, Any]]] = None, ): """Load the model and tags from the specified repository.