import ast
import random
from typing import List, Union

import torch


def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str:
    """Extract the text from a stringified bytes literal.

    Some webdatasets store ``str(some_bytes)`` instead of the decoded text,
    e.g. ``"b'hello world'"``. This recovers ``"hello world"`` from it.

    Args:
        s: Candidate string. Only values shaped like ``b'...'`` or ``b"..."``
            are converted; anything else is returned unchanged.
        encoding: Codec used to decode the recovered bytes.

    Returns:
        The decoded text. If the literal cannot be parsed or decoded, the
        text between the quotes is returned as a best effort.
    """
    looks_like_bytes_repr = (
        len(s) >= 3
        and s[0] == "b"
        and s[1] in ("'", '"')
        and s[-1] == s[1]
    )
    if not looks_like_bytes_repr:
        # Not a bytes repr: slicing would silently eat real characters.
        return s

    inner = s[2:-1]

    # Parse the literal so escape sequences emitted by ``str(bytes)`` are
    # turned back into the original bytes. ``literal_eval`` only evaluates
    # literals, so it is safe to run on dataset content.
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError, MemoryError, RecursionError):
        parsed = None

    try:
        if isinstance(parsed, bytes):
            return parsed.decode(encoding)
        # Not a valid bytes literal (e.g. it contains raw non-ASCII
        # characters): fall back to re-encoding the text between the quotes.
        return inner.encode("utf-8").decode(encoding)
    except (UnicodeDecodeError, UnicodeEncodeError):
        return inner


def dropout_caption(
    caption: Union[str, List[str]], dropout_p: float = 0
) -> Union[str, List[str]]:
    """Randomly replace a caption with an empty one.

    Args:
        caption: A single caption or a list of captions.
        dropout_p: Probability of dropping. ``random.random()`` yields values
            in ``[0.0, 1.0)``, so ``0`` never drops and ``1`` always drops.

    Returns:
        ``caption`` itself when it is kept; otherwise ``""`` for a string
        input, or a list of empty strings of the same length for any other
        input.
    """
    keep = random.random() >= dropout_p
    if keep:
        return caption
    if isinstance(caption, str):
        return ""
    return ["" for _ in range(len(caption))]


def dropout_embeddings_to_zero(
    embed: torch.Tensor, dropout_p: float = 0
) -> torch.Tensor:
    """Randomly replace an embedding tensor with zeros.

    Args:
        embed: Tensor to potentially zero out. It is never modified in place.
        dropout_p: Probability of dropping. ``0`` never drops and ``1`` always
            drops.

    Returns:
        ``embed`` itself when it is kept, otherwise a new
        ``torch.zeros_like(embed)`` tensor.
    """
    if random.random() < dropout_p:
        return torch.zeros_like(embed)
    return embed


def remove_prefix(text: str, prefixes: List[str]) -> str:
    """Strip the first matching prefix from ``text``.

    Prefixes are tried in the given order and only the first match is
    removed; the remainder is then stripped of surrounding whitespace.

    Args:
        text: Input string.
        prefixes: Candidate prefixes, in priority order.

    Returns:
        ``text`` without the first matching prefix and surrounding
        whitespace, or ``text`` unchanged when no prefix matches.
    """
    for candidate in prefixes:
        if text.startswith(candidate):
            # Slice instead of ``str.removeprefix`` so this also runs on
            # Python versions older than 3.9.
            return text[len(candidate):].strip()
    return text