Spaces:
Runtime error
Runtime error
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM | |
| from .segment_anything import build_sam_vit_h | |
| from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config | |
| from .configuration_evf import EvfConfig | |
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """Soft DICE loss between mask logits and binary targets.

    The logits are squashed with a sigmoid, dimensions 1 and 2 of both
    tensors are merged, and the DICE ratio is reduced over the last
    dimension of the result. Probabilities and targets are divided by
    ``scale`` before being summed.

    Args:
        inputs: Raw (pre-sigmoid) predictions. Needs at least three
            dimensions, because dimensions 1 and 2 are flattened together.
        targets: Tensor with the same shape as ``inputs`` storing the binary
            label of every element (0 for the negative class, 1 for the
            positive class).
        num_masks: Normaliser applied to the summed loss.
        scale: Divisor applied to probabilities and targets before summing.
        eps: Smoothing term added to both numerator and denominator.

    Returns:
        A 0-dim tensor: the summed loss divided by ``num_masks + 1e-8``.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    labels = targets.flatten(1, 2)

    scaled_probs = probs / scale
    overlap = 2 * (scaled_probs * labels).sum(-1)
    total = scaled_probs.sum(-1) + (labels / scale).sum(-1)

    per_row = 1 - (overlap + eps) / (total + eps)
    return per_row.sum() / (num_masks + 1e-8)
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """Binary cross-entropy on logits, averaged per row and normalised.

    The element-wise loss is computed without reduction, dimensions 1 and 2
    are merged, the mean is taken over dimension 1 of the result, and the
    means are summed.

    Args:
        inputs: Raw (pre-sigmoid) predictions. Needs at least three
            dimensions, because dimensions 1 and 2 are flattened together.
        targets: Tensor with the same shape as ``inputs`` storing the binary
            label of every element (0 for the negative class, 1 for the
            positive class).
        num_masks: Normaliser applied to the summed loss.

    Returns:
        Loss tensor: the summed means divided by ``num_masks + 1e-8``.
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    row_means = elementwise.flatten(1, 2).mean(1)
    return row_means.sum() / (num_masks + 1e-8)
class EvfSamModel(PreTrainedModel):
    """Text-prompted segmentation model pairing a SAM backbone with BEiT-3.

    A BEiT-3 encoder (``mm_extractor``) consumes an image together with
    prompt token ids. The first position (along dim 1) of its
    ``"encoder_out"`` is projected by ``text_hidden_fcs`` and handed to the
    SAM prompt encoder as ``text_embeds``; the SAM mask decoder then predicts
    masks from the SAM image embeddings.

    Keyword arguments read by ``__init__`` (all optional):
        vision_pretrained: Forwarded to ``build_sam_vit_h``.
        encoder_pretrained: Path of a checkpoint whose ``"model"`` entry is
            loaded (non-strictly) into the BEiT-3 wrapper.
        dice_loss_weight, bce_loss_weight: Multipliers of the two loss terms
            computed by ``forward``. They default to ``None`` and must be
            provided before ``forward`` is used to compute a loss.
        train_mask_decoder: When true, the SAM mask decoder is unfrozen.
        train_prompt_encoder: When true, ``prompt_encoder.no_mask_embed`` is
            unfrozen.
    """

    config_class = EvfConfig

    def __init__(
        self,
        config,
        **kwargs
    ):
        super().__init__(config)

        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
        self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
        self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
        self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
        self.initialize_evf_modules(config)

    def initialize_evf_modules(self, config):
        """Build the SAM model, the BEiT-3 encoder and the projection head.

        Raises:
            NotImplementedError: If ``config.sam_scale`` is not ``"huge"``.
            AttributeError: If ``config.mm_extractor_scale`` is neither
                ``"base"`` nor ``"large"``.
            AssertionError: If ``config.hidden_size`` differs from the BEiT-3
                embedding dimension.
        """
        # SAM
        if config.sam_scale == "huge":
            self.visual_model = build_sam_vit_h(self.vision_pretrained)
        else:
            raise NotImplementedError

        # Freeze all of SAM first, then selectively re-enable sub-modules.
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if self.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True
        if self.train_prompt_encoder:
            self.visual_model.prompt_encoder.no_mask_embed.requires_grad_(True)

        # beit-3
        if self.config.mm_extractor_scale == "base":
            beit_config = _get_base_config()
        elif self.config.mm_extractor_scale == "large":
            beit_config = _get_large_config()
        else:
            raise AttributeError("model config should contain key 'mm_extractor_scale', with value 'base' or 'large'.")

        self.mm_extractor = BEiT3Wrapper(beit_config)
        if self.encoder_pretrained is not None:
            # Load onto CPU so a checkpoint saved from a GPU can still be read
            # on a host without CUDA; load_state_dict copies the values into
            # the module's own parameters afterwards.
            # NOTE: torch.load unpickles the file - only use trusted checkpoints.
            beit_state_dict = torch.load(
                self.encoder_pretrained, map_location="cpu"
            )["model"]
            self.mm_extractor.load_state_dict(
                beit_state_dict,
                strict=False
            )

        for param in self.mm_extractor.parameters():
            param.requires_grad = True

        # Projection layer
        in_dim = config.hidden_size
        assert in_dim == beit_config.encoder_embed_dim, \
            f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim)
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode images with the SAM image encoder, one image at a time.

        Runs under ``torch.no_grad()``. The CUDA cache is emptied before each
        image and once after the loop; the per-image embeddings are
        concatenated along dim 0.
        """
        with torch.no_grad():
            image_embeddings_list = []
            for i in range(pixel_values.shape[0]):
                torch.cuda.empty_cache()
                image_embeddings = self.visual_model.image_encoder(
                    pixel_values[i].unsqueeze(0)
                )
                image_embeddings_list.append(image_embeddings)
            torch.cuda.empty_cache()
            image_embeddings = torch.cat(image_embeddings_list, 0)
        return image_embeddings

    def _extract_prompt_embeds(self, images_evf, input_ids, text_padding_position):
        """Run BEiT-3 and project the first output position with ``text_hidden_fcs``."""
        output = self.mm_extractor.beit3(
            visual_tokens=images_evf,
            textual_tokens=input_ids,
            text_padding_position=text_padding_position,
        )
        feat = output["encoder_out"][:, :1, ...]
        return self.text_hidden_fcs[0](feat)

    def _predict_low_res_masks(self, image_embeddings, text_embeds, multimask_output):
        """Run the SAM prompt encoder and mask decoder for one set of prompts.

        With ``multimask_output`` set, the masks are reordered by descending
        predicted IoU and only the first one is kept.
        """
        (
            sparse_embeddings,
            dense_embeddings,
        ) = self.visual_model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
            text_embeds=text_embeds,
        )
        sparse_embeddings = sparse_embeddings.to(text_embeds.dtype)
        low_res_masks, iou_predictions = self.visual_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        if multimask_output:
            sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
            low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
        return low_res_masks

    def forward(
        self,
        images: torch.FloatTensor,
        images_evf: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        """Predict masks for a batch and, unless ``inference`` is set, the loss.

        Args:
            images: Batch fed to the SAM image encoder.
            images_evf: Batch fed to BEiT-3; entry ``i`` is repeated once per
                prompt belonging to image ``i``.
            input_ids: Prompt token ids for BEiT-3.
            attention_masks: Mask over ``input_ids``; its logical negation is
                passed to BEiT-3 as ``text_padding_position``. Integer masks
                are converted to bool first.
            offset: Boundaries into the prompt dimension; prompts
                ``offset[i]:offset[i + 1]`` belong to image ``i``.
            masks_list: Ground-truth masks, one entry per image.
            label_list: Per-image tensors whose ``.shape`` is used as the
                original size when post-processing masks.
            resize_list: Per-image ``input_size`` for mask post-processing.
            inference: When true, return predictions instead of losses.

        Returns:
            ``{"pred_masks", "gt_masks"}`` when ``inference`` is true,
            otherwise ``{"loss", "mask_bce_loss", "mask_dice_loss",
            "mask_loss"}``.
        """
        image_embeddings = self.get_visual_embs(images)
        batch_size = image_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # Number of prompts per image, computed once as plain ints and reused
        # for both the image expansion and the feature split below.
        prompts_per_image = [
            int(offset[i + 1] - offset[i]) for i in range(batch_size)
        ]

        images_evf = torch.cat(
            [
                images_evf[i].unsqueeze(0).expand(count, -1, -1, -1).contiguous()
                for i, count in enumerate(prompts_per_image)
            ],
            dim=0,
        )

        # `~` is a bitwise NOT: on an integer mask it maps 1 -> -2 and 0 -> -1,
        # which are both truthy. Converting to bool first makes it a logical
        # NOT for integer and bool masks alike.
        feat = self._extract_prompt_embeds(
            images_evf, input_ids, ~attention_masks.bool()
        )
        feat = torch.split(feat, prompts_per_image)

        pred_masks = []
        for i, text_embeds in enumerate(feat):
            low_res_masks = self._predict_low_res_masks(
                image_embeddings[i].unsqueeze(0),
                text_embeds,
                multimask_output=False,
            )
            pred_mask = self.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx, pred_mask in enumerate(pred_masks):
            gt_mask = gt_masks[batch_idx]
            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), f"gt_mask.shape: {gt_mask.shape}, pred_mask.shape: {pred_mask.shape}"

            # Each helper normalises by the mask count; multiplying it back
            # weights every image by its number of masks before the global
            # normalisation below.
            mask_count = gt_mask.shape[0]
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=mask_count)
                * mask_count
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=mask_count)
                * mask_count
            )
            num_masks += mask_count

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = mask_loss

        return {
            "loss": loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def inference(
        self,
        images,
        images_evf,
        input_ids,
        resize_list,
        original_size_list,
        multimask_output=False,
    ):
        """Predict masks for the given images and prompt without computing a loss.

        No text position is marked as padding (an all-zero
        ``text_padding_position`` is used). Only the first entries of
        ``resize_list`` and ``original_size_list`` are read.

        Only the SAM image encoder runs under ``torch.no_grad()`` here; wrap
        the call in ``torch.no_grad()`` to disable autograd for the rest.

        Returns:
            ``pred_mask[:, 0]`` of the post-processed mask prediction.
        """
        with torch.no_grad():
            image_embeddings = self.visual_model.image_encoder(images)

        feat = self._extract_prompt_embeds(
            images_evf, input_ids, torch.zeros_like(input_ids)
        )
        low_res_masks = self._predict_low_res_masks(
            image_embeddings, feat, multimask_output
        )
        pred_mask = self.visual_model.postprocess_masks(
            low_res_masks,
            input_size=resize_list[0],
            original_size=original_size_list[0],
        )
        return pred_mask[:, 0]
# Import-time side effect: register the "evf" model type with AutoConfig and
# map EvfConfig to EvfSamModel in the AutoModelForCausalLM registry.
AutoConfig.register("evf", EvfConfig)
AutoModelForCausalLM.register(EvfConfig, EvfSamModel)