From 6d1cbc0b2bb34778300bd40da38c6b40706a9b68 Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Wed, 10 Jul 2024 16:53:32 +0900 Subject: [PATCH] feat(threading): allow use external rich.console --- wdtagger/__init__.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index 3e47494..623fd1b 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -9,14 +9,9 @@ import huggingface_hub import numpy as np import onnxruntime as rt import pandas as pd -import rich -import rich.live from PIL import Image from rich.logging import RichHandler -# Access console for rich text and logging -console = rich.get_console() - HF_TOKEN = os.environ.get("HF_TOKEN", "") MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" @@ -192,6 +187,7 @@ class Tagger: loglevel=logging.INFO, num_threads=None, providers=None, + console=None, ): """Initialize the Tagger object with the model repository and tokens. @@ -201,7 +197,17 @@ class Tagger: hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN. loglevel (int, optional): Logging level. Defaults to logging.INFO. num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None. + providers (list, optional): List of providers for ONNX runtime. Defaults to None. + console (rich.console.Console, optional): Rich console object. Defaults to None. """ + + if not console: + from rich import get_console + + self.console = get_console() + else: + self.console = console + if providers is None: providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] self.logger = logging.getLogger("wdtagger") @@ -234,7 +240,7 @@ class Tagger: hf_token (str, optional): HuggingFace token for authentication. Defaults to None. num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None. """ - with console.status("Loading model..."): + with self.console.status("Loading model..."): csv_path = huggingface_hub.hf_hub_download( model_repo, LABEL_FILENAME,