2 Commits

Author SHA1 Message Date
09d331d8f0 version: v0.9.0 2024-07-10 16:53:47 +09:00
6d1cbc0b2b feat(threading): allow use external rich.console 2024-07-10 16:53:32 +09:00
2 changed files with 13 additions and 7 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "wdtagger"
version = "0.8.0"
version = "0.9.0"
description = ""
authors = ["Jianqi Pan <jannchie@gmail.com>"]
readme = "README.md"

View File

@ -9,14 +9,9 @@ import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import rich
import rich.live
from PIL import Image
from rich.logging import RichHandler
# Access console for rich text and logging
console = rich.get_console()
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
@ -192,6 +187,7 @@ class Tagger:
loglevel=logging.INFO,
num_threads=None,
providers=None,
console=None,
):
"""Initialize the Tagger object with the model repository and tokens.
@ -201,7 +197,17 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
loglevel (int, optional): Logging level. Defaults to logging.INFO.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
providers (list, optional): List of providers for ONNX runtime. Defaults to None.
console (rich.console.Console, optional): Rich console object. Defaults to None.
"""
if not console:
from rich import get_console
self.console = get_console()
else:
self.console = console
if providers is None:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
self.logger = logging.getLogger("wdtagger")
@ -234,7 +240,7 @@ class Tagger:
hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
"""
with console.status("Loading model..."):
with self.console.status("Loading model..."):
csv_path = huggingface_hub.hf_hub_download(
model_repo,
LABEL_FILENAME,