import logging
import os

import torch
from torch import nn, Tensor

from bubogpt.common.dist_utils import download_cached_file
from bubogpt.common.utils import is_url
from bubogpt.models.Qformer import BertConfig, BertLMHeadModel

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class BaseQFormer(nn.Module):
    def __init__(self, freeze_qformer=False):
        super().__init__()
        self.freeze_qformer = freeze_qformer
        self.Qformer = None

    def check_and_freeze(self):
        # Subclasses must build self.Qformer (and self.query_tokens) before calling this.
        assert self.Qformer is not None
        if self.freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            # Keep the module in eval mode even if .train() is called later.
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("Freeze This QFormer")

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)
        return msg

class SequenceGenericQFormer(BaseQFormer):
    def __init__(self,
                 num_query_token: int,
                 encoder_width: int = 768,
                 freeze_qformer: bool = False,
                 q_former_model: str = "",
                 cross_attention_freq: int = 2
                 ):
        super().__init__(freeze_qformer)
        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq)
        if q_former_model != "":
            # Strip the unused language-model sub-modules and load pretrained Q-Former weights.
            self.load_Qformer(q_former_model)
        self.check_and_freeze()
    def set_Qformer(self):
        # Drop the LM head, text embeddings, and feed-forward sub-layers: the
        # Q-Former is only used as a query encoder here, so they are never needed.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

    def load_Qformer(self, q_former_model):
        self.set_Qformer()
        self.load_from_pretrained(url_or_filename=q_former_model)
    @classmethod
    def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = encoder_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
    def forward(self, input_embeds: Tensor) -> Tensor:
        # Every encoder position is valid, so the cross-attention mask is all ones.
        input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
        query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=input_embeds,
            encoder_attention_mask=input_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state
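

# A minimal usage sketch, assuming a batch of pre-computed encoder features whose
# hidden size matches `encoder_width`; the batch size, sequence length, and
# `num_query_token=32` below are illustrative values, not fixed by this module.
# With `q_former_model=""` the Q-Former is initialized from scratch (no checkpoint).
if __name__ == "__main__":
    qformer = SequenceGenericQFormer(num_query_token=32, encoder_width=768)
    dummy_features = torch.randn(2, 50, 768)  # (batch, sequence length, encoder_width)
    queries = qformer(dummy_features)
    # One output vector per learned query token, at the Q-Former's hidden size.
    print(queries.shape)  # expected: torch.Size([2, 32, 768])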