feat(provider): enable change onnx provider

This commit is contained in:
Jianqi Pan 2024-06-23 03:20:56 +09:00
parent b9a763468a
commit c82f881c85

View File

@ -2,6 +2,7 @@ import logging
import os
import time
from collections import OrderedDict
from typing import Any, Sequence
import huggingface_hub
import numpy as np
@ -15,16 +16,9 @@ from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
# Environment variables and file paths
HF_TOKEN = os.environ.get(
"HF_TOKEN", ""
) # Token for authentication with HuggingFace API
MODEL_FILENAME = "model.onnx" # ONNX model filename
LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename
available_providers = rt.get_available_providers()
supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
providers = list(set(available_providers) & set(supported_providers))
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def load_labels(dataframe) -> list[str]:
@ -174,6 +168,7 @@ class Tagger:
hf_token=HF_TOKEN,
loglevel=logging.INFO,
num_threads=None,
providers=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -184,16 +179,29 @@ class Tagger:
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
self.logger = logging.getLogger("wdtagger")
self.logger.setLevel(loglevel)
self.logger.addHandler(RichHandler())
self.model_target_size = None
self.cache_dir = cache_dir
self.hf_token = hf_token
self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
self.load_model(
model_repo,
cache_dir,
hf_token,
num_threads=num_threads,
providers=providers,
)
def load_model(
self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
self,
model_repo,
cache_dir=None,
hf_token=None,
num_threads: int = None,
providers: Sequence[str | tuple[str, dict[Any, Any]]] = None,
):
"""Load the model and tags from the specified repository.