import copy
import random
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from transformers import CLIPTokenizer


class TokenizerWrapper:
    """Tokenizer wrapper for CLIPTokenizer. Only CLIPTokenizer is supported
    currently. This wrapper is modified from https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358  # noqa

    Args:
        tokenizer (CLIPTokenizer): The tokenizer to be wrapped.
    """

    def __init__(self, tokenizer: CLIPTokenizer):
        self.wrapped = tokenizer
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        if name in self.__dict__:
            return getattr(self, name)
        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.tokenizer'."
            )

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in the current tokenizer.
        """
        # input_ids are [BOS, first_sub_token, ..., last_sub_token, EOS], so the
        # token occupies the half-open id range [token_ids[1], token_ids[-2] + 1).
        token_ids = self.__call__(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token. Defaults to 1.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token}. "
                    "Keep placeholder tokens independent."
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(
                    self.replace_placeholder_tokens_in_text(
                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                    )
                )
            return output

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the expanded sub-tokens in text back with the original
        placeholder keywords. This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(self.replace_text_with_placeholder_tokens(text[i]))
            return output

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token indices.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token indices to text.

        Args:
            token_ids: The token indices to be decoded.
            return_raw: Whether to keep the expanded placeholder sub-tokens in
                the decoded text. Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        replaced_text = self.replace_text_with_placeholder_tokens(text)
        return replaced_text

    def __repr__(self):
        """The representation of the wrapper."""
        s = super().__repr__()
        prefix = f"Wrapped Module Class: {self.wrapped.__class__}\n"
        prefix += f"Wrapped Module Name: {self.wrapped.__class__.__name__}\n"
        if getattr(self.wrapped, "name_or_path", None):
            prefix += f"From Pretrained: {self.wrapped.name_or_path}\n"
        s = prefix + s
        return s
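

# Hedged usage sketch (not part of the original module logic): shows how
# TokenizerWrapper expands a placeholder into several freshly added sub-tokens
# at call time and collapses them again on decode. Assumes the public
# 'openai/clip-vit-large-patch14' checkpoint is reachable; the placeholder name
# '<cat-toy>' is purely illustrative.
def _demo_tokenizer_wrapper():
    tokenizer = TokenizerWrapper(CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"))
    # Registers '<cat-toy>_0' ... '<cat-toy>_3' in the wrapped tokenizer.
    tokenizer.add_placeholder_token("<cat-toy>", num_vec_per_token=4)
    ids = tokenizer(
        "a photo of <cat-toy>", padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids
    # decode() maps the expanded sub-tokens back to '<cat-toy>'.
    text = tokenizer.decode(ids[0])
    return ids, text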


class EmbeddingLayerWithFixes(nn.Module):
    """The revised embedding layer to support external embeddings. The design
    of this class is inspired by https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224  # noqa

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        # The ParameterDict must exist before add_embeddings() is called, since
        # trainable external embeddings are registered into it.
        self.trainable_embeddings = nn.ParameterDict()
        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of the wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in the list of external
        embeddings.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            f"Found duplicated names in 'external_embeddings'. Name list: '{names}'"
        )

    def check_ids_overlap(self, embeddings: List[dict]):
        """Check whether the token id ranges of the external embeddings overlap.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # check that consecutive ranges do not overlap
        for idx in range(len(ids_range) - 1):
            name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
            assert ids_range[idx][1] <= ids_range[idx + 1][0], (
                f"Found ids overlapping between embeddings '{name1}' and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Use case:

        >>> # 1. Add token to tokenizer and get the token id.
        >>> tokenizer = TokenizerWrapper(
        >>>     CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32'))
        >>> # 'how much' in kiswahili
        >>> tokenizer.add_placeholder_token('ngapi', num_vec_per_token=4)
        >>>
        >>> # 2. Add external embeddings to the model.
        >>> new_embedding = {
        >>>     'name': 'ngapi',  # 'how much' in kiswahili
        >>>     'embedding': torch.ones(4, 15) * 4,
        >>>     'start': tokenizer.get_token_info('ngapi')['start'],
        >>>     'end': tokenizer.get_token_info('ngapi')['end'],
        >>>     'trainable': False  # if True, will be registered as a parameter
        >>> }
        >>> embedding_layer = nn.Embedding(10, 15)
        >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
        >>> embedding_layer_wrapper.add_embeddings(new_embedding)
        >>>
        >>> # 3. Forward tokenizer and embedding layer!
        >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
        >>> input_ids = tokenizer(
        >>>     input_text, padding='max_length', truncation=True,
        >>>     return_tensors='pt')['input_ids']
        >>> out_feat = embedding_layer_wrapper(input_ids)
        >>>
        >>> # 4. Let's validate the result!
        >>> assert (out_feat[0, 3: 7] == 4).all()
        >>> assert (out_feat[1, 5: 9] == 4).all()

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
        """
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        self.external_embeddings += embeddings
        self.check_duplicate_names(self.external_embeddings)
        self.check_ids_overlap(self.external_embeddings)

        # register trainable embeddings as parameters
        added_trainable_emb_info = []
        for embedding in embeddings:
            trainable = embedding.get("trainable", False)
            if trainable:
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = [emb["name"] for emb in embeddings]
        added_emb_info = ", ".join(added_emb_info)
        print(f"Successfully added external embeddings: {added_emb_info}.")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print(f"Successfully added trainable external embeddings: {added_trainable_emb_info}.")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace the token ids of external embeddings with 0 so they can be
        forwarded through the wrapped embedding layer.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace the rows of `embedding` that correspond to an external
        embedding. Note that `torch.cat` is used here to avoid in-place
        modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after the
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.
        """
        new_embedding = []

        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        target_ids_to_replace = list(range(start, end))
        ext_emb = external_embedding["embedding"].to(embedding.device)

        # nothing to replace in this sequence
        if not (input_ids == start).any():
            return embedding

        # start replacing
        s_idx, e_idx = 0, 0
        while e_idx < len(input_ids):
            if input_ids[e_idx] == start:
                if e_idx != 0:
                    # keep the embeddings that do not need to be replaced
                    new_embedding.append(embedding[s_idx:e_idx])
                # check that the next span of ids to replace is valid
                actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
                assert actually_ids_to_replace == target_ids_to_replace, (
                    f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
                    f"Expect '{target_ids_to_replace}' for embedding "
                    f"'{name}' but found '{actually_ids_to_replace}'."
                )
                new_embedding.append(ext_emb)
                s_idx = e_idx + end - start
                e_idx = s_idx + 1
            else:
                e_idx += 1

        if e_idx == len(input_ids):
            new_embedding.append(embedding[s_idx:e_idx])

        return torch.cat(new_embedding, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype=None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids, shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): Additional external
                embeddings for this forward pass. If not passed, only
                `self.external_embeddings` will be used. Defaults to None.
            out_dtype: The dtype of the returned embeddings. Defaults to None.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            # NOTE: this path forwards out_dtype to the wrapped layer; a plain
            # nn.Embedding does not accept that keyword.
            return self.wrapped(input_ids, out_dtype=out_dtype)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        vecs = []

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        return torch.stack(vecs).to(out_dtype)
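

# Hedged sketch (illustrative only): exercises EmbeddingLayerWithFixes on a toy
# nn.Embedding. The ids 10 and 11 stand in for two newly added placeholder
# sub-tokens that fall outside the original vocabulary of size 10; the name
# 'new_concept' and all sizes are hypothetical.
def _demo_embedding_layer_with_fixes():
    layer = EmbeddingLayerWithFixes(nn.Embedding(10, 16))
    layer.add_embeddings(
        {
            "name": "new_concept",           # hypothetical placeholder name
            "embedding": torch.ones(2, 16),  # one row per expanded sub-token
            "start": 10,
            "end": 12,
            "trainable": True,               # registered as an nn.Parameter
        }
    )
    input_ids = torch.tensor([[0, 3, 10, 11, 5]])
    out = layer(input_ids, out_dtype=torch.float32)
    assert out.shape == (1, 5, 16)
    # Positions 2 and 3 were swapped for the external embedding rows.
    assert (out[0, 2:4] == 1).all()
    return out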


def add_tokens(
    tokenizer,
    text_encoder,
    placeholder_tokens: list,
    initialize_tokens: Optional[list] = None,
    num_vectors_per_token: int = 1,
):
    """Add tokens for training.

    # TODO: support adding tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_tokens should be the same length as initialize_tokens"
    for ii in range(len(placeholder_tokens)):
        tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)

    # text_encoder.set_embedding_layer()
    embedding_layer = text_encoder.text_model.embeddings.token_embedding
    text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
    embedding_layer = text_encoder.text_model.embeddings.token_embedding

    assert embedding_layer is not None, (
        "Do not support getting the embedding layer for the current text encoder. "
        "Please check your configuration."
    )
    initialize_embedding = []
    if initialize_tokens is not None:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
    else:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer("a").input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            len_emb = temp_embedding.shape[0]
            init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            initialize_embedding.append(init_weight)

    # initialize_embedding = torch.cat(initialize_embedding, dim=0)

    token_info_all = []
    for ii in range(len(placeholder_tokens)):
        token_info = tokenizer.get_token_info(placeholder_tokens[ii])
        token_info["embedding"] = initialize_embedding[ii]
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
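

# Hedged end-to-end sketch: wires add_tokens() into a CLIP text encoder from
# transformers. It assumes the public 'openai/clip-vit-large-patch14' checkpoint
# is reachable and that CLIPTextModel exposes text_model.embeddings.token_embedding,
# as the transformers implementation does; the placeholder name '<my-style>' is
# illustrative only.
if __name__ == "__main__":
    from transformers import CLIPTextModel

    tokenizer = TokenizerWrapper(CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"))
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Register one placeholder token backed by 2 trainable embedding vectors,
    # both initialized from the embedding of the word "painting".
    add_tokens(
        tokenizer,
        text_encoder,
        placeholder_tokens=["<my-style>"],
        initialize_tokens=["painting"],
        num_vectors_per_token=2,
    )

    input_ids = tokenizer(
        ["a <my-style> portrait"], padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids
    # The wrapped embedding layer swaps in the external vectors on the fly.
    hidden = text_encoder(input_ids).last_hidden_state
    print(hidden.shape)  # expected: torch.Size([1, 77, 768]) for this checkpoint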