diff --git a/poetry.lock b/poetry.lock index 5ea51bb..f3677ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -386,37 +386,22 @@ files = [ ] [[package]] -name = "onnxruntime" +name = "onnxruntime-gpu" version = "1.18.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = false python-versions = "*" files = [ - {file = "onnxruntime-1.18.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:5a3b7993a5ecf4a90f35542a4757e29b2d653da3efe06cdd3164b91167bbe10d"}, - {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:15b944623b2cdfe7f7945690bfb71c10a4531b51997c8320b84e7b0bb59af902"}, - {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e61ce5005118064b1a0ed73ebe936bc773a102f067db34108ea6c64dd62a179"}, - {file = "onnxruntime-1.18.0-cp310-cp310-win32.whl", hash = "sha256:a4fc8a2a526eb442317d280610936a9f73deece06c7d5a91e51570860802b93f"}, - {file = "onnxruntime-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:71ed219b768cab004e5cd83e702590734f968679bf93aa488c1a7ffbe6e220c3"}, - {file = "onnxruntime-1.18.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:3d24bd623872a72a7fe2f51c103e20fcca2acfa35d48f2accd6be1ec8633d960"}, - {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f15e41ca9b307a12550bfd2ec93f88905d9fba12bab7e578f05138ad0ae10d7b"}, - {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f45ca2887f62a7b847d526965686b2923efa72538c89b7703c7b3fe970afd59"}, - {file = "onnxruntime-1.18.0-cp311-cp311-win32.whl", hash = "sha256:9e24d9ecc8781323d9e2eeda019b4b24babc4d624e7d53f61b1fe1a929b0511a"}, - {file = "onnxruntime-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:f8608398976ed18aef450d83777ff6f77d0b64eced1ed07a985e1a7db8ea3771"}, - {file = "onnxruntime-1.18.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:f1d79941f15fc40b1ee67738b2ca26b23e0181bf0070b5fb2984f0988734698f"}, - {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e8caf3a8565c853a22d323a3eebc2a81e3de7591981f085a4f74f7a60aab2d"}, - {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:498d2b8380635f5e6ebc50ec1b45f181588927280f32390fb910301d234f97b8"}, - {file = "onnxruntime-1.18.0-cp312-cp312-win32.whl", hash = "sha256:ba7cc0ce2798a386c082aaa6289ff7e9bedc3dee622eef10e74830cff200a72e"}, - {file = "onnxruntime-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:1fa175bd43f610465d5787ae06050c81f7ce09da2bf3e914eb282cb8eab363ef"}, - {file = "onnxruntime-1.18.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:0284c579c20ec8b1b472dd190290a040cc68b6caec790edb960f065d15cf164a"}, - {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47353d036d8c380558a5643ea5f7964d9d259d31c86865bad9162c3e916d1f6"}, - {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:885509d2b9ba4b01f08f7fa28d31ee54b6477953451c7ccf124a84625f07c803"}, - {file = "onnxruntime-1.18.0-cp38-cp38-win32.whl", hash = "sha256:8614733de3695656411d71fc2f39333170df5da6c7efd6072a59962c0bc7055c"}, - {file = "onnxruntime-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:47af3f803752fce23ea790fd8d130a47b2b940629f03193f780818622e856e7a"}, - {file = "onnxruntime-1.18.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:9153eb2b4d5bbab764d0aea17adadffcfc18d89b957ad191b1c3650b9930c59f"}, - {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c7fd86eca727c989bb8d9c5104f3c45f7ee45f445cc75579ebe55d6b99dfd7c"}, - {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac67a4de9c1326c4d87bcbfb652c923039b8a2446bb28516219236bec3b494f5"}, - {file = "onnxruntime-1.18.0-cp39-cp39-win32.whl", hash = "sha256:6ffb445816d06497df7a6dd424b20e0b2c39639e01e7fe210e247b82d15a23b9"}, - {file = "onnxruntime-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:46de6031cb6745f33f7eca9e51ab73e8c66037fb7a3b6b4560887c5b55ab5d5d"}, + {file = "onnxruntime_gpu-1.18.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ebee59886d4a63d54cec673c73bdcc58b5496f9946e3c539af51550c6f222e88"}, + {file = "onnxruntime_gpu-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:c33a2bfd100bd2b82542c72deaaed40d3de858ac9624dc5730e548869dac2f2f"}, + {file = "onnxruntime_gpu-1.18.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:aa673c044f450b21163265cca2e35eb1ded03d48fb01dec6be1c412811e9b2b0"}, + {file = "onnxruntime_gpu-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:1dfb1172de0043f7bdc32724619f5bc151dbdc5bb8e50e806271d2efa8d72715"}, + {file = "onnxruntime_gpu-1.18.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:40ce4aec8499352b96f8c67d1f2fe4ac170dbc9281dd8ebe36339d9e1c0bcdb7"}, + {file = "onnxruntime_gpu-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:97df451777322a534dbedff49d43a4f6f74bd5258ccf31b2161e1af606bc724d"}, + {file = "onnxruntime_gpu-1.18.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:05b5f17bc78586741715903b5ec7fcef4c64d136aee94d6583dd17035190d550"}, + {file = "onnxruntime_gpu-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:40fab33312d02fdaa93b8cc14ffb1fc3aceaef5db50383a579fd8b7ad5a03b17"}, + {file = "onnxruntime_gpu-1.18.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:e0d9e200afe57c69c6407c90189c7f16a55ce326d1ac45864353e13bfe574b3c"}, + {file = "onnxruntime_gpu-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:a7d13f0983eec46857acab8bdace7e3178428d17ad3c6d5b84b5b211c98e5acd"}, ] [package.dependencies] @@ -614,22 +599,33 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "protobuf" -version = "5.27.0" +version = "5.27.1" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-5.27.0-cp310-abi3-win32.whl", hash = "sha256:2f83bf341d925650d550b8932b71763321d782529ac0eaf278f5242f513cc04e"}, - {file = "protobuf-5.27.0-cp310-abi3-win_amd64.whl", hash = "sha256:b276e3f477ea1eebff3c2e1515136cfcff5ac14519c45f9b4aa2f6a87ea627c4"}, - {file = "protobuf-5.27.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:744489f77c29174328d32f8921566fb0f7080a2f064c5137b9d6f4b790f9e0c1"}, - {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:f51f33d305e18646f03acfdb343aac15b8115235af98bc9f844bf9446573827b"}, - {file = "protobuf-5.27.0-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:56937f97ae0dcf4e220ff2abb1456c51a334144c9960b23597f044ce99c29c89"}, - {file = "protobuf-5.27.0-cp38-cp38-win32.whl", hash = "sha256:a17f4d664ea868102feaa30a674542255f9f4bf835d943d588440d1f49a3ed15"}, - {file = "protobuf-5.27.0-cp38-cp38-win_amd64.whl", hash = "sha256:aabbbcf794fbb4c692ff14ce06780a66d04758435717107c387f12fb477bf0d8"}, - {file = "protobuf-5.27.0-cp39-cp39-win32.whl", hash = "sha256:587be23f1212da7a14a6c65fd61995f8ef35779d4aea9e36aad81f5f3b80aec5"}, - {file = "protobuf-5.27.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cb65fc8fba680b27cf7a07678084c6e68ee13cab7cace734954c25a43da6d0f"}, - {file = "protobuf-5.27.0-py3-none-any.whl", hash = "sha256:673ad60f1536b394b4fa0bcd3146a4130fcad85bfe3b60eaa86d6a0ace0fa374"}, - {file = "protobuf-5.27.0.tar.gz", hash = "sha256:07f2b9a15255e3cf3f137d884af7972407b556a7a220912b252f26dc3121e6bf"}, + {file = "protobuf-5.27.1-cp310-abi3-win32.whl", hash = "sha256:3adc15ec0ff35c5b2d0992f9345b04a540c1e73bfee3ff1643db43cc1d734333"}, + {file = "protobuf-5.27.1-cp310-abi3-win_amd64.whl", hash = "sha256:25236b69ab4ce1bec413fd4b68a15ef8141794427e0b4dc173e9d5d9dffc3bcd"}, + {file = "protobuf-5.27.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4e38fc29d7df32e01a41cf118b5a968b1efd46b9c41ff515234e794011c78b17"}, + {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:917ed03c3eb8a2d51c3496359f5b53b4e4b7e40edfbdd3d3f34336e0eef6825a"}, + {file = "protobuf-5.27.1-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:ee52874a9e69a30271649be88ecbe69d374232e8fd0b4e4b0aaaa87f429f1631"}, + {file = "protobuf-5.27.1-cp38-cp38-win32.whl", hash = "sha256:7a97b9c5aed86b9ca289eb5148df6c208ab5bb6906930590961e08f097258107"}, + {file = "protobuf-5.27.1-cp38-cp38-win_amd64.whl", hash = "sha256:f6abd0f69968792da7460d3c2cfa7d94fd74e1c21df321eb6345b963f9ec3d8d"}, + {file = "protobuf-5.27.1-cp39-cp39-win32.whl", hash = "sha256:dfddb7537f789002cc4eb00752c92e67885badcc7005566f2c5de9d969d3282d"}, + {file = "protobuf-5.27.1-cp39-cp39-win_amd64.whl", hash = "sha256:39309898b912ca6febb0084ea912e976482834f401be35840a008da12d189340"}, + {file = "protobuf-5.27.1-py3-none-any.whl", hash = "sha256:4ac7249a1530a2ed50e24201d6630125ced04b30619262f06224616e0030b6cf"}, + {file = "protobuf-5.27.1.tar.gz", hash = "sha256:df5e5b8e39b7d1c25b186ffdf9f44f40f810bbcc9d2b71d9d3156fee5a9adf15"}, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, ] [[package]] @@ -679,6 +675,26 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-benchmark" +version = "4.0.0" +description = "A ``pytest`` fixture for benchmarking code. It will group the tests into rounds that are calibrated to the chosen timer." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-benchmark-4.0.0.tar.gz", hash = "sha256:fb0785b83efe599a6a956361c0691ae1dbb5318018561af10f3e915caa0048d1"}, + {file = "pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6"}, +] + +[package.dependencies] +py-cpuinfo = "*" +pytest = ">=3.8" + +[package.extras] +aspect = ["aspectlib"] +elasticsearch = ["elasticsearch"] +histogram = ["pygal", "pygaljs"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -901,4 +917,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "79947fb7b5011c525936d8a22c4d026001528b3ff1e120ac7e20796e1a35fcea" +content-hash = "4ff24b133648154ec552eae7f59f370d80b6506eba1bba7fcd70bd70b8e07af8" diff --git a/pyproject.toml b/pyproject.toml index f8cd16c..10dea54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,15 +7,16 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" -onnxruntime = "^1.18.0" pillow = "^10.3.0" pandas = "^2.2.2" huggingface-hub = "^0.23.3" rich = "^13.7.1" +onnxruntime-gpu = "^1.18.0" [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" +pytest-benchmark = "^4.0.0" [build-system] requires = ["poetry-core"] diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index ea7cd0c..922041b 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -22,6 +22,10 @@ HF_TOKEN = os.environ.get( MODEL_FILENAME = "model.onnx" # ONNX model filename LABEL_FILENAME = "selected_tags.csv" # Labels CSV filename +available_providers = rt.get_available_providers() +supported_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"] +providers = list(set(available_providers) & set(supported_providers)) + def load_labels(dataframe) -> list[str]: """Load labels from a dataframe and process tag names. @@ -220,7 +224,11 @@ class Tagger: if num_threads: options.intra_op_num_threads = num_threads options.inter_op_num_threads = num_threads - model = rt.InferenceSession(model_path, options) + model = rt.InferenceSession( + model_path, + options, + providers=providers, + ) _, height, _, _ = model.get_inputs()[0].shape self.model_target_size = height self.model = model