chore(type): better result type
This commit is contained in:
parent
5bcbf991e8
commit
c7fbdeb7b1
@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Sequence, Union
|
||||
from typing import Any, List, Literal, Sequence, Union, overload
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@ -115,17 +115,17 @@ class Result:
|
||||
self.rating_data = rating_data
|
||||
|
||||
@property
|
||||
def general_tags(self):
|
||||
def general_tags(self) -> tuple[str]:
|
||||
"""Return general tags as a tuple."""
|
||||
return tuple(self.general_tag_data.keys())
|
||||
|
||||
@property
|
||||
def character_tags(self):
|
||||
def character_tags(self) -> tuple[str]:
|
||||
"""Return character tags as a tuple."""
|
||||
return tuple(self.character_tag_data.keys())
|
||||
|
||||
@property
|
||||
def rating(self):
|
||||
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
|
||||
"""Return the highest rated tag."""
|
||||
return max(self.rating_data, key=self.rating_data.get)
|
||||
|
||||
@ -307,6 +307,14 @@ class Tagger:
|
||||
array = array[:, :, [2, 1, 0]]
|
||||
return array
|
||||
|
||||
@overload
|
||||
def tag(self, image: Input, general_threshold: float = 0.35, character_threshold: float = 0.9) -> Result: ...
|
||||
|
||||
@overload
|
||||
def tag(
|
||||
self, image: List[Input], general_threshold: float = 0.35, character_threshold: float = 0.9
|
||||
) -> List[Result]: ...
|
||||
|
||||
def tag(
|
||||
self,
|
||||
image: Union[Input, List[Input]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user