feat(threading): allow use external rich.console

This commit is contained in:
Jianqi Pan 2024-07-10 16:53:32 +09:00
parent c910823173
commit 6d1cbc0b2b

View File

@ -9,14 +9,9 @@ import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
@ -192,6 +187,7 @@ class Tagger:
loglevel=logging.INFO,
num_threads=None,
providers=None,
console=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -201,7 +197,17 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
self.logger = logging.getLogger("wdtagger")
@ -234,7 +240,7 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
with self.console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,