diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py
index 3e47494..623fd1b 100644
--- a/wdtagger/__init__.py
+++ b/wdtagger/__init__.py
@@ -9,14 +9,9 @@ import huggingface_hub
 import numpy as np
 import onnxruntime as rt
 import pandas as pd
-import rich
-import rich.live
 from PIL import Image
 from rich.logging import RichHandler
 
-# Access console for rich text and logging
-console = rich.get_console()
-
 HF_TOKEN = os.environ.get("HF_TOKEN", "")
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"
@@ -192,6 +187,7 @@ class Tagger:
         loglevel=logging.INFO,
         num_threads=None,
         providers=None,
+        console=None,
     ):
         """Initialize the Tagger object with the model repository and tokens.
 
@@ -201,7 +197,17 @@ class Tagger:
             hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
             loglevel (int, optional): Logging level. Defaults to logging.INFO.
             num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
+            providers (list, optional): List of providers for ONNX runtime. Defaults to None.
+            console (rich.console.Console, optional): Rich console object. Defaults to None.
         """
+
+        if not console:
+            from rich import get_console
+
+            self.console = get_console()
+        else:
+            self.console = console
+
         if providers is None:
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
         self.logger = logging.getLogger("wdtagger")
@@ -234,7 +240,7 @@ class Tagger:
             hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
             num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
         """
-        with console.status("Loading model..."):
+        with self.console.status("Loading model..."):
             csv_path = huggingface_hub.hf_hub_download(
                 model_repo,
                 LABEL_FILENAME,
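
With this change, callers can inject their own Rich console instead of relying on the module-level `rich.get_console()` that was removed. A minimal usage sketch, assuming the other `Tagger` constructor arguments keep their defaults; the `Console(stderr=True)` setup is illustrative and not part of this diff:

```python
from rich.console import Console

from wdtagger import Tagger

# Caller-owned console; here it routes status output to stderr (assumption for illustration).
console = Console(stderr=True)

# Pass the console explicitly; omitting console= falls back to rich.get_console(),
# preserving the previous behaviour.
tagger = Tagger(console=console)
```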