From c7fbdeb7b150cb232c9feeadd2b36fd7621d7d66 Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Sat, 24 Aug 2024 19:12:07 +0900 Subject: [PATCH] chore(type): better result type --- wdtagger/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/wdtagger/__init__.py b/wdtagger/__init__.py index b0c469c..bfb0940 100644 --- a/wdtagger/__init__.py +++ b/wdtagger/__init__.py @@ -3,7 +3,7 @@ import os import time from collections import OrderedDict from pathlib import Path -from typing import Any, List, Sequence, Union +from typing import Any, List, Literal, Sequence, Union, overload import huggingface_hub import numpy as np @@ -115,17 +115,17 @@ class Result: self.rating_data = rating_data @property - def general_tags(self): + def general_tags(self) -> tuple[str]: """Return general tags as a tuple.""" return tuple(self.general_tag_data.keys()) @property - def character_tags(self): + def character_tags(self) -> tuple[str]: """Return character tags as a tuple.""" return tuple(self.character_tag_data.keys()) @property - def rating(self): + def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]: """Return the highest rated tag.""" return max(self.rating_data, key=self.rating_data.get) @@ -307,6 +307,14 @@ class Tagger: array = array[:, :, [2, 1, 0]] return array + @overload + def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ... + + @overload + def tag( + self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9 + ) -> List[Result]: ... + def tag( self, image: Union[Input, List[Input]],