diff --git a/src/wdtagger/__init__.py b/src/wdtagger/__init__.py index d855dc2..e6f9482 100644 --- a/src/wdtagger/__init__.py +++ b/src/wdtagger/__init__.py @@ -153,12 +153,27 @@ class Result: @property def general_tags(self) -> tuple[str, ...]: """Return general tags as a tuple.""" - return tuple(self.general_tag_data.keys()) + + return tuple( + d.replace("_", " ").replace("(", R"\(").replace(")", R"\)") + for d in sorted( + self.general_tag_data, + key=lambda k: self.general_tag_data[k], + reverse=True, + ) + ) @property def character_tags(self) -> tuple[str, ...]: """Return character tags as a tuple.""" - return tuple(self.character_tag_data.keys()) + return tuple( + d.replace("_", " ").replace("(", R"\(").replace(")", R"\)") + for d in sorted( + self.character_tag_data, + key=lambda k: self.character_tag_data[k], + reverse=True, + ) + ) @property def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]: @@ -174,7 +189,7 @@ class Result: reverse=True, ) string = [x[0] for x in string] - return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)") + return ", ".join(string) @property def character_tags_string(self) -> str: @@ -185,7 +200,7 @@ class Result: reverse=True, ) string = [x[0] for x in string] - return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)") + return ", ".join(string) @property def all_tags(self) -> list[str]: @@ -194,7 +209,7 @@ class Result: @property def all_tags_string(self) -> str: - return ", ".join(self.all_tags).replace("_", " ").replace("(", R"\(").replace(")", R"\)") + return ", ".join(self.all_tags) def __str__(self) -> str: """Return a formatted string representation of the tags and their ratings.""" @@ -261,7 +276,7 @@ class Tagger: self._do_load_model(model_repo) def _do_load_model(self, model_repo: str) -> None: - model: nn.Module = timm.create_model("hf-hub:" + model_repo).eval() + model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval() state_dict = timm.models.load_state_dict_from_hf(model_repo) model.load_state_dict(state_dict)