import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import (
    T5Tokenizer,
    T5EncoderModel,
    CLIPTokenizer,
    CLIPTextModel,
    AutoProcessor,
    CLIPVisionModelWithProjection,
)

from iopaint.model.anytext.ldm.util import count_params

def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )

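# Illustrative note (not from the original file): for a [bsz, 77] padding mask,
# `_expand_mask(mask, torch.float32)` returns a [bsz, 1, 77, 77] additive mask that is
# 0.0 at attended positions and the dtype minimum at padded positions, ready to be
# added to the attention logits.
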
def _build_causal_attention_mask(bsz, seq_len, dtype):
    # lazily create the causal attention mask
    # pytorch uses an additive attention mask; fill with -inf
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.finfo(dtype).min)
    mask.triu_(1)  # zero out the diagonal and the lower triangle
    mask = mask.unsqueeze(1)  # expand mask to [bsz, 1, seq_len, seq_len]
    return mask

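# Illustrative note: `_build_causal_attention_mask(1, 3, torch.float32)` yields a
# [1, 1, 3, 3] additive mask whose strictly upper triangle holds the dtype minimum and
# whose diagonal and lower triangle are 0, so each position attends only to itself and
# to earlier positions.
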
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    def encode(self, x):
        return x

class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc

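# Usage sketch (hypothetical values; `embed_dim` should match the model's
# cross-attention context dimension):
#   embedder = ClassEmbedder(embed_dim=512, n_classes=1000)
#   c = embedder({"class": torch.tensor([3, 7])})  # -> [2, 1, 512]
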
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

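# Usage sketch (assumes the default google/t5-v1_1-large checkpoint; with
# device="cpu" the weights already live on the right device, while for "cuda" the
# caller must also move the module, e.g. `.to("cuda")`):
#   t5 = FrozenT5Embedder(device="cpu")
#   z = t5.encode(["a photo of a cat"])  # -> [1, 77, hidden_size]
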
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

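# Usage sketch (default openai/clip-vit-large-patch14 weights; as above, moving the
# module to `device` is left to the caller):
#   clip = FrozenCLIPEmbedder(device="cpu", layer="last")
#   z = clip.encode(["a photo of a cat"])  # -> [1, 77, hidden_size]
# With layer="hidden" and e.g. layer_idx=-2, the penultimate hidden states are
# returned instead, the "clip skip" behaviour used by some SD-style checkpoints.
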
class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]

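# Note: FrozenCLIPT5Encoder.encode returns a list [clip_z, t5_z] rather than a single
# tensor, leaving it to the caller to decide how the two context embeddings are
# combined; both sub-encoders are constructed frozen with their default checkpoints.
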
class FrozenCLIPEmbedderT3(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        use_vision=False,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        if use_vision:
            self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
            self.processor = AutoProcessor.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        def embedding_forward(
            self,
            input_ids=None,
            position_ids=None,
            inputs_embeds=None,
            embedding_manager=None,
        ):
            seq_length = (
                input_ids.shape[-1]
                if input_ids is not None
                else inputs_embeds.shape[-2]
            )
            if position_ids is None:
                position_ids = self.position_ids[:, :seq_length]
            if inputs_embeds is None:
                inputs_embeds = self.token_embedding(input_ids)
            if embedding_manager is not None:
                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
            position_embeddings = self.position_embedding(position_ids)
            embeddings = inputs_embeds + position_embeddings
            return embeddings

        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
            self.transformer.text_model.embeddings
        )

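        # embedding_forward mirrors CLIPTextEmbeddings.forward but adds an optional
        # embedding_manager hook, which can swap in externally learned token
        # embeddings (textual-inversion style) before the position embeddings are
        # added.
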
        def encoder_forward(
            self,
            inputs_embeds,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):
            output_attentions = (
                output_attentions
                if output_attentions is not None
                else self.config.output_attentions
            )
            output_hidden_states = (
                output_hidden_states
                if output_hidden_states is not None
                else self.config.output_hidden_states
            )
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )
            encoder_states = () if output_hidden_states else None
            all_attentions = () if output_attentions else None
            hidden_states = inputs_embeds
            for idx, encoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states,)
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            return hidden_states

        self.transformer.text_model.encoder.forward = encoder_forward.__get__(
            self.transformer.text_model.encoder
        )

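        # encoder_forward is a trimmed copy of CLIPEncoder.forward: no gradient
        # checkpointing and no output wrapper; it returns the raw hidden states of
        # the final layer.
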
        def text_encoder_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            output_attentions = (
                output_attentions
                if output_attentions is not None
                else self.config.output_attentions
            )
            output_hidden_states = (
                output_hidden_states
                if output_hidden_states is not None
                else self.config.output_hidden_states
            )
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )
            if input_ids is None:
                raise ValueError("You have to specify input_ids")
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                embedding_manager=embedding_manager,
            )
            bsz, seq_len = input_shape
            # CLIP's text model uses a causal mask, prepare it here.
            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
            causal_attention_mask = _build_causal_attention_mask(
                bsz, seq_len, hidden_states.dtype
            ).to(hidden_states.device)
            # expand attention_mask
            if attention_mask is not None:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            last_hidden_state = self.final_layer_norm(last_hidden_state)
            return last_hidden_state

        self.transformer.text_model.forward = text_encoder_forward.__get__(
            self.transformer.text_model
        )

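        # text_encoder_forward rebuilds CLIPTextTransformer.forward without the
        # pooled-output (EOS-token gathering) step; only the layer-normalized last
        # hidden states are returned.
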
        def transformer_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            return self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                embedding_manager=embedding_manager,
            )

        self.transformer.forward = transformer_forward.__get__(self.transformer)

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, **kwargs):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        z = self.transformer(input_ids=tokens, **kwargs)
        return z

    def encode(self, text, **kwargs):
        return self(text, **kwargs)

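# Usage sketch (hypothetical; `manager` stands for any callable taking
# (input_ids, inputs_embeds) and returning modified embeddings, such as the
# project's embedding manager):
#   clip_t3 = FrozenCLIPEmbedderT3(device="cpu")
#   z = clip_t3.encode(["a sign that says 'hello'"])
#   z_mod = clip_t3.encode(["a sign that says 'hello'"], embedding_manager=manager)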