feat(provider): enable change onnx provider
commit c82f881c85
parent b9a763468a
@@ -2,6 +2,7 @@ import logging
 import os
 import time
 from collections import OrderedDict
+from typing import Any, Sequence

 import huggingface_hub
 import numpy as np
@@ -15,16 +16,9 @@ from rich.logging import RichHandler
 # Access console for rich text and logging
 console = rich.get_console()

-# Environment variables and file paths
-HF_TOKEN = os.environ.get(
-    "HF_TOKEN", ""
-)  # Token for authentication with HuggingFace API
-MODEL_FILENAME = "model.onnx"  # ONNX model filename
-LABEL_FILENAME = "selected_tags.csv"  # Labels CSV filename
-
-available_providers = rt.get_available_providers()
-supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
-providers = list(set(available_providers) & set(supported_providers))
+HF_TOKEN = os.environ.get("HF_TOKEN", "")
+MODEL_FILENAME = "model.onnx"
+LABEL_FILENAME = "selected_tags.csv"


 def load_labels(dataframe) -> list[str]:
@@ -174,6 +168,7 @@ class Tagger:
         hf_token=HF_TOKEN,
         loglevel=logging.INFO,
         num_threads=None,
+        providers=None,
     ):
         """Initialize the Tagger object with the model repository and tokens.

@@ -184,16 +179,29 @@ class Tagger:
             loglevel (int, optional): Logging level. Defaults to logging.INFO.
             num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
         """
+        if providers is None:
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
         self.logger = logging.getLogger("wdtagger")
         self.logger.setLevel(loglevel)
         self.logger.addHandler(RichHandler())
         self.model_target_size = None
         self.cache_dir = cache_dir
         self.hf_token = hf_token
-        self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
+        self.load_model(
+            model_repo,
+            cache_dir,
+            hf_token,
+            num_threads=num_threads,
+            providers=providers,
+        )

     def load_model(
-        self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
+        self,
+        model_repo,
+        cache_dir=None,
+        hf_token=None,
+        num_threads: int = None,
+        providers: Sequence[str | tuple[str, dict[Any, Any]]] = None,
     ):
         """Load the model and tags from the specified repository.

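Not part of the commit, but a minimal usage sketch of the new providers argument. It assumes the package exposes Tagger at its root (the logger name in the diff is "wdtagger") and that model_repo keeps a usable default; the provider filtering simply mirrors the module-level intersection logic this commit removes.

    # Usage sketch (assumptions: `from wdtagger import Tagger` works and model_repo has a default).
    import onnxruntime as rt
    from wdtagger import Tagger

    # Keep only providers this onnxruntime build actually offers, preserving priority order.
    preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    providers = [p for p in preferred if p in rt.get_available_providers()]

    tagger = Tagger(providers=providers)

    # Per-provider options may be given as (name, options) tuples, matching the
    # Sequence[str | tuple[str, dict[Any, Any]]] annotation on load_model.
    tagger_gpu1 = Tagger(providers=[("CUDAExecutionProvider", {"device_id": 1}), "CPUExecutionProvider"])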
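Why the annotation allows (name, options) tuples: a providers sequence in that shape is exactly what onnxruntime.InferenceSession accepts, so load_model presumably passes it straight through when it builds the session. A generic sketch of that standard ONNX Runtime call, not shown in this diff:

    # Standard ONNX Runtime usage; the model path here is illustrative only.
    import onnxruntime as rt

    providers = [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
    session = rt.InferenceSession("model.onnx", providers=providers)

    # Reports the providers actually attached to the session, in priority order.
    print(session.get_providers())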