refactor(wdtagger): improve tag formatting and model loading

This commit is contained in:
Jianqi Pan 2024-11-26 22:48:41 +09:00
parent bd85a5643a
commit 9fb556c41c

View File

@ -153,12 +153,27 @@ class Result:
@property
def general_tags(self) -> tuple[str, ...]:
"""Return general tags as a tuple."""
return tuple(self.general_tag_data.keys())
return tuple(
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
for d in sorted(
self.general_tag_data,
key=lambda k: self.general_tag_data[k],
reverse=True,
)
)
@property
def character_tags(self) -> tuple[str, ...]:
"""Return character tags as a tuple."""
return tuple(self.character_tag_data.keys())
return tuple(
d.replace("_", " ").replace("(", R"\(").replace(")", R"\)")
for d in sorted(
self.character_tag_data,
key=lambda k: self.character_tag_data[k],
reverse=True,
)
)
@property
def rating(self) -> Literal["general", "sensitive", "questionable", "explicit"]:
@ -174,7 +189,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(string)
@property
def character_tags_string(self) -> str:
@ -185,7 +200,7 @@ class Result:
reverse=True,
)
string = [x[0] for x in string]
return ", ".join(string).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(string)
@property
def all_tags(self) -> list[str]:
@ -194,7 +209,7 @@ class Result:
@property
def all_tags_string(self) -> str:
return ", ".join(self.all_tags).replace("_", " ").replace("(", R"\(").replace(")", R"\)")
return ", ".join(self.all_tags)
def __str__(self) -> str:
"""Return a formatted string representation of the tags and their ratings."""
@ -261,7 +276,7 @@ class Tagger:
self._do_load_model(model_repo)
def _do_load_model(self, model_repo: str) -> None:
model: nn.Module = timm.create_model("hf-hub:" + model_repo).eval()
model: nn.Module = timm.create_model(f"hf-hub:{model_repo}").eval()
state_dict = timm.models.load_state_dict_from_hf(model_repo)
model.load_state_dict(state_dict)