jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import random
from typing import List, Union
import torch
def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str:
    """
    Extracts the actual string from a stringified bytes array (common in some webdatasets).

    Example: "b'hello world'" -> "hello world"

    Only values that really look like a stringified bytes literal are unwrapped: a
    leading ``b``, followed by a quote character, with the same quote character at the
    very end. Any other string is returned unchanged, so calling this on an already
    clean caption is a no-op instead of chopping off its first two and last characters.

    Escape sequences inside the literal (for example a literal backslash followed by
    ``xc3``) are not interpreted; only the ``b'`` / ``'`` wrapper is removed.

    Args:
        s: The candidate string, e.g. ``"b'hello world'"``.
        encoding: Codec used to decode the UTF-8 bytes of the unwrapped text.

    Returns:
        The unwrapped text. If the encode/decode round-trip fails, the unwrapped
        text is returned without re-decoding. If ``s`` is not wrapped, ``s`` itself.
    """
    # A stringified bytes object is at least "b''" (3 chars) and is delimited by
    # matching single or double quotes, mirroring what repr(bytes) can produce.
    looks_wrapped = (
        len(s) >= 3
        and s[0] == "b"
        and s[1] in ("'", '"')
        and s[-1] == s[1]
    )
    if not looks_wrapped:
        return s

    inner = s[2:-1]
    try:
        return inner.encode("utf-8").decode(encoding)
    except (UnicodeDecodeError, UnicodeEncodeError):
        # Best effort: keep the unwrapped text when it cannot be re-decoded.
        return inner
def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
    """
    Randomly blank out a caption (or a whole list of captions).

    A single sample is drawn from ``random.random()``. When that sample is greater
    than or equal to ``dropout_p`` the input is handed back untouched; otherwise the
    caption is replaced by empty text of the same kind.

    Args:
        caption: One caption string, or a list of caption strings.
        dropout_p: Threshold compared against the random sample. Defaults to 0.

    Returns:
        The original ``caption`` object when it is kept; otherwise ``""`` for a
        string input, or a list of empty strings with the same length as the input.
    """
    keep_caption = random.random() >= dropout_p
    if keep_caption:
        return caption

    # Dropped: mirror the input's shape with empty text.
    return "" if isinstance(caption, str) else [""] * len(caption)
def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
    """
    Randomly replace an embedding tensor with zeros.

    A single sample is drawn from ``random.random()``. When that sample is greater
    than or equal to ``dropout_p`` the input tensor is returned as-is; otherwise a
    new tensor from ``torch.zeros_like(embed)`` is returned.

    Args:
        embed: The embedding tensor.
        dropout_p: Threshold compared against the random sample. Defaults to 0.

    Returns:
        ``embed`` itself when kept, or a freshly created all-zero tensor built with
        ``torch.zeros_like`` when dropped.
    """
    keep_embedding = random.random() >= dropout_p
    return embed if keep_embedding else torch.zeros_like(embed)
def remove_prefix(text: str, prefixes: List[str]) -> str:
    """
    Strip the first matching prefix from ``text``.

    The prefixes are tried in the order given. For the first one that ``text``
    starts with, that prefix is removed and the remainder is whitespace-stripped
    on both sides. If none match, ``text`` is returned exactly as received (no
    stripping).

    Args:
        text: The string to clean up.
        prefixes: Candidate prefixes, checked in order.

    Returns:
        The stripped remainder after the first matching prefix, or the unmodified
        ``text`` when no prefix matches.
    """
    # Lazily find the first candidate that matches; None means "no match".
    first_match = next(
        (candidate for candidate in prefixes if text.startswith(candidate)),
        None,
    )
    if first_match is None:
        return text
    return text.removeprefix(first_match).strip()