Compare commits

10 Commits

Author | SHA1 | Date
---|---|---
 | 01935d9e82 | 
 | 5b39dc7735 | 
 | dbec094a3d | 
 | 1e6b04c0ec | 
 | be7085f2f7 | 
 | fd67f54fcc | 
 | 4e5221d7a8 | 
 | 5e4629b8ea | 
 | ae2838f58e | 
 | f9ec9de157 | 

.vscode/settings.json | 9 (vendored, new file)
@@ -0,0 +1,9 @@
+{
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter",
+        "editor.formatOnType": true,
+        "editor.codeActionsOnSave": {
+            "source.organizeImports": "explicit"
+        },
+    },
+}

README.md | 14
@@ -23,3 +23,17 @@ image = Image.open("image.jpg")
 result = tagger.tag(image)
 print(result)
 ```
+
+You can pass a list of images to the tagger to use batch processing, which is faster than tagging images one at a time (tested on an RTX 3090):
+
+```log
+---------------------------------------------------------------------------------- benchmark 'tagger': 5 tests -----------------------------------------------------------------------------------
+Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+test_tagger_benchmark[16] 540.8711 (1.0) 598.5156 (1.04) 558.2777 (1.0) 22.2954 (4.10) 549.9650 (1.0) 21.7318 (2.51) 2;2 1.7912 (1.0) 10 1
+test_tagger_benchmark[8] 558.9445 (1.03) 576.7220 (1.0) 567.9235 (1.02) 5.4381 (1.0) 568.7336 (1.03) 8.6569 (1.0) 2;0 1.7608 (0.98) 10 1
+test_tagger_benchmark[4] 590.6479 (1.09) 626.7126 (1.09) 597.9712 (1.07) 11.0124 (2.03) 594.5067 (1.08) 10.7656 (1.24) 1;1 1.6723 (0.93) 10 1
+test_tagger_benchmark[2] 622.8689 (1.15) 643.5122 (1.12) 630.1096 (1.13) 7.2365 (1.33) 627.1716 (1.14) 9.5823 (1.11) 3;0 1.5870 (0.89) 10 1
+test_tagger_benchmark[1] 700.6986 (1.30) 816.3089 (1.42) 721.7431 (1.29) 33.9031 (6.23) 712.6850 (1.30) 12.8756 (1.49) 1;1 1.3855 (0.77) 10 1
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+```
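For reference, the batch call added to the README uses the same `tag` API as the single-image example above it; a minimal sketch of what that looks like in practice (the image file names here are placeholders, not files from the repository):

```python
from PIL import Image

from wdtagger import Tagger

tagger = Tagger()

# Placeholder paths; any PIL-compatible images work here.
images = [Image.open(path) for path in ["a.jpg", "b.jpg", "c.jpg"]]

# Passing a list returns one result per image
# (a single image returns a single result).
results = tagger.tag(images)
for result in results:
    print(result)
```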

poetry.lock | 94 (generated)
@@ -386,37 +386,22 @@ files = [
 ]
 
 [[package]]
-name = "onnxruntime"
+name = "onnxruntime-gpu"
 version = "1.18.0"
 description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
 optional = false
 python-versions = "*"
 files = [
-    {file = "onnxruntime-1.18.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:5a3b7993a5ecf4a90f35542a4757e29b2d653da3efe06cdd3164b91167bbe10d"},
-    {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:15b944623b2cdfe7f7945690bfb71c10a4531b51997c8320b84e7b0bb59af902"},
-    {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e61ce5005118064b1a0ed73ebe936bc773a102f067db34108ea6c64dd62a179"},
-    {file = "onnxruntime-1.18.0-cp310-cp310-win32.whl", hash = "sha256:a4fc8a2a526eb442317d280610936a9f73deece06c7d5a91e51570860802b93f"},
-    {file = "onnxruntime-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:71ed219b768cab004e5cd83e702590734f968679bf93aa488c1a7ffbe6e220c3"},
-    {file = "onnxruntime-1.18.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:3d24bd623872a72a7fe2f51c103e20fcca2acfa35d48f2accd6be1ec8633d960"},
-    {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f15e41ca9b307a12550bfd2ec93f88905d9fba12bab7e578f05138ad0ae10d7b"},
-    {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f45ca2887f62a7b847d526965686b2923efa72538c89b7703c7b3fe970afd59"},
-    {file = "onnxruntime-1.18.0-cp311-cp311-win32.whl", hash = "sha256:9e24d9ecc8781323d9e2eeda019b4b24babc4d624e7d53f61b1fe1a929b0511a"},
-    {file = "onnxruntime-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:f8608398976ed18aef450d83777ff6f77d0b64eced1ed07a985e1a7db8ea3771"},
-    {file = "onnxruntime-1.18.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:f1d79941f15fc40b1ee67738b2ca26b23e0181bf0070b5fb2984f0988734698f"},
-    {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e8caf3a8565c853a22d323a3eebc2a81e3de7591981f085a4f74f7a60aab2d"},
-    {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:498d2b8380635f5e6ebc50ec1b45f181588927280f32390fb910301d234f97b8"},
-    {file = "onnxruntime-1.18.0-cp312-cp312-win32.whl", hash = "sha256:ba7cc0ce2798a386c082aaa6289ff7e9bedc3dee622eef10e74830cff200a72e"},
-    {file = "onnxruntime-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:1fa175bd43f610465d5787ae06050c81f7ce09da2bf3e914eb282cb8eab363ef"},
-    {file = "onnxruntime-1.18.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:0284c579c20ec8b1b472dd190290a040cc68b6caec790edb960f065d15cf164a"},
-    {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47353d036d8c380558a5643ea5f7964d9d259d31c86865bad9162c3e916d1f6"},
-    {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:885509d2b9ba4b01f08f7fa28d31ee54b6477953451c7ccf124a84625f07c803"},
-    {file = "onnxruntime-1.18.0-cp38-cp38-win32.whl", hash = "sha256:8614733de3695656411d71fc2f39333170df5da6c7efd6072a59962c0bc7055c"},
-    {file = "onnxruntime-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:47af3f803752fce23ea790fd8d130a47b2b940629f03193f780818622e856e7a"},
-    {file = "onnxruntime-1.18.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:9153eb2b4d5bbab764d0aea17adadffcfc18d89b957ad191b1c3650b9930c59f"},
-    {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c7fd86eca727c989bb8d9c5104f3c45f7ee45f445cc75579ebe55d6b99dfd7c"},
-    {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac67a4de9c1326c4d87bcbfb652c923039b8a2446bb28516219236bec3b494f5"},
-    {file = "onnxruntime-1.18.0-cp39-cp39-win32.whl", hash = "sha256:6ffb445816d06497df7a6dd424b20e0b2c39639e01e7fe210e247b82d15a23b9"},
-    {file = "onnxruntime-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:46de6031cb6745f33f7eca9e51ab73e8c66037fb7a3b6b4560887c5b55ab5d5d"},
+    {file = "onnxruntime_gpu-1.18.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ebee59886d4a63d54cec673c73bdcc58b5496f9946e3c539af51550c6f222e88"},
+    {file = "onnxruntime_gpu-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:c33a2bfd100bd2b82542c72deaaed40d3de858ac9624dc5730e548869dac2f2f"},
+    {file = "onnxruntime_gpu-1.18.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:aa673c044f450b21163265cca2e35eb1ded03d48fb01dec6be1c412811e9b2b0"},
+    {file = "onnxruntime_gpu-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:1dfb1172de0043f7bdc32724619f5bc151dbdc5bb8e50e806271d2efa8d72715"},
+    {file = "onnxruntime_gpu-1.18.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:40ce4aec8499352b96f8c67d1f2fe4ac170dbc9281dd8ebe36339d9e1c0bcdb7"},
+    {file = "onnxruntime_gpu-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:97df451777322a534dbedff49d43a4f6f74bd5258ccf31b2161e1af606bc724d"},
+    {file = "onnxruntime_gpu-1.18.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:05b5f17bc78586741715903b5ec7fcef4c64d136aee94d6583dd17035190d550"},
+    {file = "onnxruntime_gpu-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:40fab33312d02fdaa93b8cc14ffb1fc3aceaef5db50383a579fd8b7ad5a03b17"},
+    {file = "onnxruntime_gpu-1.18.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:e0d9e200afe57c69c6407c90189c7f16a55ce326d1ac45864353e13bfe574b3c"},
+    {file = "onnxruntime_gpu-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:a7d13f0983eec46857acab8bdace7e3178428d17ad3c6d5b84b5b211c98e5acd"},
 ]
 
 [package.dependencies]
@@ -614,22 +599,33 @@ testing = ["pytest", "pytest-benchmark"]
 
 [[package]]
 name = "protobuf"
-version = "5.27.0"
+version = "5.27.1"
 description = ""
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "protobuf-5.27.0-cp310-abi3-win32.whl", hash = "sha256:2f83bf341d925650d550b8932b71763321d782529ac0eaf278f5242f513cc04e"},
-    {file = "protobuf-5.27.0-cp310-abi3-win_amd64.whl", hash = "sha256:b276e3f477ea1eebff3c2e1515136cfcff5ac14519c45f9b4aa2f6a87ea627c4"},
-    {file = "protobuf-5.27.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:744489f77c29174328d32f8921566fb0f7080a2f064c5137b9d6f4b790f9e0c1"},
-    {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:f51f33d305e18646f03acfdb343aac15b8115235af98bc9f844bf9446573827b"},
-    {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:56937f97ae0dcf4e220ff2abb1456c51a334144c9960b23597f044ce99c29c89"},
-    {file = "protobuf-5.27.0-cp38-cp38-win32.whl", hash = "sha256:a17f4d664ea868102feaa30a674542255f9f4bf835d943d588440d1f49a3ed15"},
-    {file = "protobuf-5.27.0-cp38-cp38-win_amd64.whl", hash = "sha256:aabbbcf794fbb4c692ff14ce06780a66d04758435717107c387f12fb477bf0d8"},
-    {file = "protobuf-5.27.0-cp39-cp39-win32.whl", hash = "sha256:587be23f1212da7a14a6c65fd61995f8ef35779d4aea9e36aad81f5f3b80aec5"},
-    {file = "protobuf-5.27.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cb65fc8fba680b27cf7a07678084c6e68ee13cab7cace734954c25a43da6d0f"},
-    {file = "protobuf-5.27.0-py3-none-any.whl", hash = "sha256:673ad60f1536b394b4fa0bcd3146a4130fcad85bfe3b60eaa86d6a0ace0fa374"},
-    {file = "protobuf-5.27.0.tar.gz", hash = "sha256:07f2b9a15255e3cf3f137d884af7972407b556a7a220912b252f26dc3121e6bf"},
+    {file = "protobuf-5.27.1-cp310-abi3-win32.whl", hash = "sha256:3adc15ec0ff35c5b2d0992f9345b04a540c1e73bfee3ff1643db43cc1d734333"},
+    {file = "protobuf-5.27.1-cp310-abi3-win_amd64.whl", hash = "sha256:25236b69ab4ce1bec413fd4b68a15ef8141794427e0b4dc173e9d5d9dffc3bcd"},
+    {file = "protobuf-5.27.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4e38fc29d7df32e01a41cf118b5a968b1efd46b9c41ff515234e794011c78b17"},
+    {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:917ed03c3eb8a2d51c3496359f5b53b4e4b7e40edfbdd3d3f34336e0eef6825a"},
+    {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:ee52874a9e69a30271649be88ecbe69d374232e8fd0b4e4b0aaaa87f429f1631"},
+    {file = "protobuf-5.27.1-cp38-cp38-win32.whl", hash = "sha256:7a97b9c5aed86b9ca289eb5148df6c208ab5bb6906930590961e08f097258107"},
+    {file = "protobuf-5.27.1-cp38-cp38-win_amd64.whl", hash = "sha256:f6abd0f69968792da7460d3c2cfa7d94fd74e1c21df321eb6345b963f9ec3d8d"},
+    {file = "protobuf-5.27.1-cp39-cp39-win32.whl", hash = "sha256:dfddb7537f789002cc4eb00752c92e67885badcc7005566f2c5de9d969d3282d"},
+    {file = "protobuf-5.27.1-cp39-cp39-win_amd64.whl", hash = "sha256:39309898b912ca6febb0084ea912e976482834f401be35840a008da12d189340"},
+    {file = "protobuf-5.27.1-py3-none-any.whl", hash = "sha256:4ac7249a1530a2ed50e24201d6630125ced04b30619262f06224616e0030b6cf"},
+    {file = "protobuf-5.27.1.tar.gz", hash = "sha256:df5e5b8e39b7d1c25b186ffdf9f44f40f810bbcc9d2b71d9d3156fee5a9adf15"},
 ]
 
+[[package]]
+name = "py-cpuinfo"
+version = "9.0.0"
+description = "Get CPU info with pure Python"
+optional = false
+python-versions = "*"
+files = [
+    {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"},
+    {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
+]
+
 [[package]]
@@ -679,6 +675,26 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-benchmark"
+version = "4.0.0"
+description = "A ``pytest`` fixture for benchmarking code. It will group the tests into rounds that are calibrated to the chosen timer."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-benchmark-4.0.0.tar.gz", hash = "sha256:fb0785b83efe599a6a956361c0691ae1dbb5318018561af10f3e915caa0048d1"},
+    {file = "pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6"},
+]
+
+[package.dependencies]
+py-cpuinfo = "*"
+pytest = ">=3.8"
+
+[package.extras]
+aspect = ["aspectlib"]
+elasticsearch = ["elasticsearch"]
+histogram = ["pygal", "pygaljs"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -901,4 +917,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "79947fb7b5011c525936d8a22c4d026001528b3ff1e120ac7e20796e1a35fcea"
+content-hash = "4ff24b133648154ec552eae7f59f370d80b6506eba1bba7fcd70bd70b8e07af8"

pyproject.toml
@@ -1,21 +1,22 @@
 [tool.poetry]
 name = "wdtagger"
-version = "0.2.0"
+version = "0.5.0"
 description = ""
 authors = ["Jianqi Pan <jannchie@gmail.com>"]
 readme = "README.md"
 
 [tool.poetry.dependencies]
 python = "^3.10"
-onnxruntime = "^1.18.0"
 pillow = "^10.3.0"
 pandas = "^2.2.2"
 huggingface-hub = "^0.23.3"
 rich = "^13.7.1"
+onnxruntime-gpu = "^1.18.0"
 
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.2.2"
+pytest-benchmark = "^4.0.0"
 
 [build-system]
 requires = ["poetry-core"]
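The dependency change above swaps onnxruntime for onnxruntime-gpu, and the module now picks the CUDA execution provider when ONNX Runtime reports it. A quick way to check what the installed runtime actually offers (a sketch, assuming onnxruntime-gpu and a working CUDA setup):

```python
import onnxruntime as rt

# With onnxruntime-gpu and CUDA available, this list should contain
# "CUDAExecutionProvider"; otherwise the tagger falls back to the CPU provider.
print(rt.get_available_providers())
```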

tests/benchmark_tagger.py | 32 (new file)
@@ -0,0 +1,32 @@
+import os
+
+import pytest
+from PIL import Image
+
+from wdtagger import Tagger
+
+tagger = Tagger()
+image_dir = "./tests/images/"
+image_paths = [os.path.join(image_dir, image) for image in os.listdir(image_dir)] * 16
+images = [Image.open(image_path) for image_path in image_paths]
+
+
+def tag_in_batch(images, batch=1):
+    for i in range(0, len(images), batch):
+        tagger.tag(images[i : i + batch])
+
+
+@pytest.mark.benchmark(
+    group="tagger",
+    min_rounds=10,
+    warmup=False,
+    disable_gc=True,
+)
+@pytest.mark.parametrize("batch", [1, 2, 4, 8, 16])
+def test_tagger_benchmark(benchmark, batch):
+    # warmup
+    tag_in_batch(images[:1])
+    benchmark.pedantic(tag_in_batch, args=(images, batch), iterations=1, rounds=10)
+
+
+# cmd: pytest tests/benchmark_tagger.py -v

wdtagger module (file path not shown)
@@ -1,3 +1,4 @@
+import logging
 import os
 import time
 from collections import OrderedDict
@@ -9,6 +10,7 @@ import pandas as pd
 import rich
 import rich.live
 from PIL import Image
+from rich.logging import RichHandler
 
 # Access console for rich text and logging
 console = rich.get_console()
@@ -20,6 +22,10 @@ HF_TOKEN = os.environ.get(
 MODEL_FILENAME = "model.onnx"  # ONNX model filename
 LABEL_FILENAME = "selected_tags.csv"  # Labels CSV filename
 
+available_providers = rt.get_available_providers()
+supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
+providers = list(set(available_providers) & set(supported_providers))
+
 
 def load_labels(dataframe) -> list[str]:
     """Load labels from a dataframe and process tag names.
@@ -67,7 +73,13 @@ def load_labels(dataframe) -> list[str]:
 
 
 class Result:
-    def __init__(self, pred, sep_tags, general_threshold=0.35, character_threshold=0.9):
+    def __init__(
+        self,
+        pred,
+        sep_tags,
+        general_threshold=0.35,
+        character_threshold=0.9,
+    ):
         """Initialize the Result object to store tagging results.
 
         Args:
@@ -160,6 +172,8 @@ class Tagger:
         model_repo="SmilingWolf/wd-swinv2-tagger-v3",
         cache_dir=None,
         hf_token=HF_TOKEN,
+        loglevel=logging.INFO,
+        num_threads=None,
     ):
         """Initialize the Tagger object with the model repository and tokens.
 
@@ -167,19 +181,27 @@
             model_repo (str): Repository name on HuggingFace.
             cache_dir (str, optional): Directory to cache the model. Defaults to None.
             hf_token (str, optional): HuggingFace token for authentication. Defaults to HF_TOKEN.
+            loglevel (int, optional): Logging level. Defaults to logging.INFO.
+            num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
         """
+        self.logger = logging.getLogger("wdtagger")
+        self.logger.setLevel(loglevel)
+        self.logger.addHandler(RichHandler())
         self.model_target_size = None
         self.cache_dir = cache_dir
         self.hf_token = hf_token
-        self.load_model(model_repo, cache_dir, hf_token)
+        self.load_model(model_repo, cache_dir, hf_token, num_threads=num_threads)
 
-    def load_model(self, model_repo, cache_dir=None, hf_token=None):
+    def load_model(
+        self, model_repo, cache_dir=None, hf_token=None, num_threads: int = None
+    ):
         """Load the model and tags from the specified repository.
 
         Args:
             model_repo (str): Repository name on HuggingFace.
             cache_dir (str, optional): Directory to cache the model. Defaults to None.
             hf_token (str, optional): HuggingFace token for authentication. Defaults to None.
+            num_threads (int, optional): Number of threads for ONNX runtime. Defaults to None.
         """
         with console.status("Loading model..."):
             csv_path = huggingface_hub.hf_hub_download(
@@ -188,6 +210,7 @@ class Tagger:
                 cache_dir=cache_dir,
                 use_auth_token=hf_token,
             )
+
             model_path = huggingface_hub.hf_hub_download(
                 model_repo,
                 MODEL_FILENAME,
@@ -197,8 +220,15 @@
 
             tags_df = pd.read_csv(csv_path)
             self.sep_tags = load_labels(tags_df)
-
-            model = rt.InferenceSession(model_path)
+            options = rt.SessionOptions()
+            if num_threads:
+                options.intra_op_num_threads = num_threads
+                options.inter_op_num_threads = num_threads
+            model = rt.InferenceSession(
+                model_path,
+                options,
+                providers=providers,
+            )
             _, height, _, _ = model.get_inputs()[0].shape
             self.model_target_size = height
             self.model = model
@@ -234,8 +264,9 @@
             (target_size, target_size),
             Image.BICUBIC,
         )
-
-        return np.asarray(padded_image, dtype=np.float32)
+        array = np.asarray(padded_image, dtype=np.float32)
+        array = array[:, :, [2, 1, 0]]
+        return array
 
     def tag(
         self,
@@ -266,9 +297,9 @@
         ]
         duration = time.time() - started_at
         image_length = len(images)
-        console.log(f"Tagging {image_length} image{
-            's' if image_length > 1 else ''
-        } took {duration:.2f} seconds.")
+        self.logger.info(
+            f"Tagging {image_length} image{ 's' if image_length > 1 else ''} took {duration:.2f} seconds."
+        )
         return results[0] if len(results) == 1 else results
 
 
@@ -278,4 +309,4 @@ if __name__ == "__main__":
     tagger = Tagger()
     image = Image.open("./tests/images/赤松楓.9d64b955.jpeg")
     result = tagger.tag(image)
-    console.log(result)
+    tagger.logger.info(result)
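Taken together, the changes to the Tagger constructor add loglevel and num_threads keyword arguments and route timing output through a RichHandler-backed logger. A usage sketch of the new options (the argument values and the image path are illustrative, not from the repository):

```python
import logging

from PIL import Image

from wdtagger import Tagger

# loglevel and num_threads are the keyword arguments introduced in this diff;
# 4 threads and DEBUG logging are example values, and "sample.jpg" is a placeholder.
tagger = Tagger(loglevel=logging.DEBUG, num_threads=4)

result = tagger.tag(Image.open("sample.jpg"))
tagger.logger.info(result)
```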