| | |
| | |
| | |
| | |
| | |
| | from typing import Optional, Tuple, Union |
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| |
|
| | from einops import rearrange |
| | from timm.models.layers import DropPath |
| | from torch import nn |
| | from transformers.activations import ACT2FN |
| | from transformers.modeling_outputs import (BaseModelOutput, |
| | BaseModelOutputWithPooling) |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import logging |
| |
|
| | from .configuration_navil_vit import NaViLVisionConfig |
| | from .modular_intern_vit import ( |
| | InternVisionFlashAttention2, |
| | InternVisionSdpaAttention, |
| | InternMLP, |
| | NORM2FN, |
| | InternVisionRotaryEmbedding, |
| | ) |
| |
|
# FlashAttention is an optional dependency. `has_flash_attn` records whether its
# kernels could be imported; the encoder layer uses the flag to choose between
# the FlashAttention and the SDPA attention implementation.
try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except Exception:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
    # still propagate. It is kept broader than ImportError on purpose: any
    # failure while importing the extension should fall back, not crash.
    print('FlashAttention is not installed.')
    has_flash_attn = False
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class NaViLVisionEmbeddingsAnyRes(nn.Module):
    """Patch embedding for the any-resolution NaViL vision tower.

    A single strided ``nn.Conv2d`` (kernel size == stride == ``config.patch_size``)
    projects 3-channel pixel values to ``config.hidden_size`` channels. The
    result is flattened to one row per entry of the input's leading dimension.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Side length of the square patch group implied by the downsample
        # ratio (e.g. a ratio of 0.5 gives a merge size of 2).
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        # Patch count of a square `image_size` input; `num_positions` adds one.
        # Neither value is read inside this class.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed pixel values.

        Args:
            pixel_values: 4-D tensor ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, hidden_size * (H // patch_size) * (W // patch_size))``.
            When every input item is exactly one patch this is ``(N, hidden_size)``.

        Raises:
            ValueError: if ``pixel_values`` is not 4-dimensional.
        """
        if pixel_values.dim() != 4:
            raise ValueError(
                f'pixel_values must be 4-dimensional (N, C, H, W), got shape {tuple(pixel_values.shape)}'
            )
        # Cast to the convolution's parameter dtype so that e.g. float32 pixels
        # work with half-precision weights instead of raising a dtype mismatch.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        return patch_embeds.flatten(1)
| |
|
| |
|
class NaViLVisionEncoderLayerAnyRes(nn.Module):
    """One pre-norm transformer block of the any-resolution vision encoder.

    The block applies ``norm -> attention`` and ``norm -> MLP``, each followed
    by a learnable per-channel scale (``ls1`` / ``ls2``), optional stochastic
    depth, and a residual connection.
    """

    def __init__(self, config: NaViLVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        # The attention backend is chosen once, from the module-level flag that
        # records whether the flash-attn import succeeded.
        if has_flash_attn:
            self.attn = InternVisionFlashAttention2(config)
        else:
            self.attn = InternVisionSdpaAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Per-channel residual scales, initialised to `config.initializer_factor`.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        # Stochastic depth is only instantiated for a strictly positive rate.
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens,
            rotary_pos_emb
    ) -> torch.Tensor:
        """Run the block on a sequence of tokens.

        Args:
            hidden_states: token features whose last dimension is ``embed_dim``.
                The encoder in this file passes a packed 2-D tensor
                ``(total_tokens, embed_dim)``.
            cu_seqlens: cumulative sequence boundaries, forwarded unchanged to
                the attention module.
            rotary_pos_emb: rotary position table, forwarded unchanged to the
                attention module.

        Returns:
            The updated hidden states (a single tensor, not a tuple).
        """
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.drop_path1(attn_out * self.ls1)

        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = hidden_states + self.drop_path2(mlp_out * self.ls2)

        return hidden_states
| |
|
| |
|
class NaViLVisionEncoderAnyRes(nn.Module):
    """
    Transformer encoder made of `config.num_hidden_layers` layers, each a
    [`NaViLVisionEncoderLayerAnyRes`], operating on a packed token sequence.

    Layers whose index is listed in `config.fullatt_block_indexes` (or all
    layers when that setting is `None`) receive per-frame sequence boundaries.
    Every other layer receives per-window boundaries.

    Args:
        config (`NaViLVisionConfig`):
            The vision configuration for this encoder.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config

        # Stochastic-depth rate rises linearly from 0 to `config.drop_path_rate`.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            NaViLVisionEncoderLayerAnyRes(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = True

        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_pos_emb = InternVisionRotaryEmbedding(head_dim // 2)

        self.merge_size = int(1.0 / config.downsample_ratio)
        # Number of patches in one merge_size x merge_size group.
        self.merge_unit = self.merge_size * self.merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size

    def _to_merge_order(self, ids, h, w):
        """Flatten an ``(h, w)`` grid so each merge_size x merge_size group is contiguous."""
        merge = self.merge_size
        grouped = ids.reshape(h // merge, merge, w // merge, merge)
        return grouped.permute(0, 2, 1, 3).flatten()

    def _to_window_order(self, tensor, window_index, seq_len):
        """Permute whole merge units of a packed ``(seq_len, dim)`` tensor by ``window_index``."""
        units = tensor.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        return units[window_index, :, :].reshape(seq_len, -1)

    def rot_pos_emb(self, grid_thw):
        """Build the rotary table for every patch described by ``grid_thw``.

        Args:
            grid_thw: integer tensor with one ``(t, h, w)`` row per image/video,
                where ``h`` and ``w`` are patch-grid sizes divisible by
                ``merge_size``.

        Returns:
            Tensor with one row per patch, ordered merge-group by merge-group.
            Each row concatenates the rotary entries looked up for the patch's
            row index and column index.
        """
        pos_ids = []
        # A single `.tolist()` replaces per-element tensor scalars in the loop.
        for t, h, w in grid_thw.tolist():
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            frame_ids = torch.stack(
                [self._to_merge_order(row_ids, h, w), self._to_merge_order(col_ids, h, w)], dim=-1
            )
            # Every one of the `t` frames shares the same spatial positions.
            pos_ids.append(frame_ids.repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # The table only needs to cover the largest height/width seen.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def get_window_index(self, grid_thw):
        """Compute the permutation that groups merge units into attention windows.

        Args:
            grid_thw: integer tensor with one ``(t, h, w)`` row per image/video.

        Returns:
            ``(window_index, cu_window_seqlens)``. ``window_index`` is a 1-D
            tensor of merge-unit indices in window-by-window order.
            ``cu_window_seqlens`` is a Python list of cumulative window lengths
            counted in patches, starting with 0. Empty padding windows add
            repeated entries to that list.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.merge_size
        assert vit_merger_window_size > 0, 'window_size must be at least merge_size'

        for grid_t, grid_h, grid_w in grid_thw.tolist():
            llm_grid_h = grid_h // self.merge_size
            llm_grid_w = grid_w // self.merge_size
            num_units = grid_t * llm_grid_h * llm_grid_w
            index = torch.arange(num_units).reshape(grid_t, llm_grid_h, llm_grid_w)
            # A grid that is already a multiple of the window still gets one
            # full window of padding. Those windows are empty and only add
            # repeated entries to `cu_window_seqlens`.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding cells so they can be dropped after regrouping.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merge units in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            # Offset by the merge units contributed by previous images.
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += num_units
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`):
                Packed patch embeddings of all images, with `sequence_length` divisible by
                `merge_size ** 2`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
                `(t, h, w)` patch-grid size of every image/video. Required despite the `None` default.

        Raises:
            ValueError: if `grid_thw` is not provided.
        """
        if grid_thw is None:
            raise ValueError('grid_thw is required to build rotary embeddings and attention windows')

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds
        index_dtype = grid_thw.dtype if torch.jit.is_tracing() else torch.int32

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=index_dtype,
        )
        # Empty padding windows produce repeated boundaries; drop them.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens and their rotary rows so that each attention window is
        # contiguous. Whole merge units move together.
        # NOTE(review): tokens stay in window order on return; nothing in this
        # class restores the original order — confirm callers expect that.
        seq_len, _ = hidden_states.size()
        hidden_states = self._to_window_order(hidden_states, window_index, seq_len)
        rotary_pos_emb = self._to_window_order(rotary_pos_emb, window_index, seq_len)

        # Full-attention boundaries: one sequence of h * w patches per frame.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=index_dtype,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # Keep both boundary tensors on the same device as the tokens.
        cu_seqlens = cu_seqlens.to(hidden_states.device)

        for idx, encoder_layer in enumerate(self.layers):
            if (self.fullatt_block_indexes is None) or (idx in self.fullatt_block_indexes):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Only `hidden_states` goes through the checkpoint as an input;
                # the other arguments are bound via `partial`.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    partial(encoder_layer, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb),
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    rotary_pos_emb=rotary_pos_emb,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )
| |
|
| |
|
class NaViLVisionModelAnyRes(PreTrainedModel):
    """Any-resolution NaViL vision tower: patch embedding followed by the encoder.

    The final hidden states are returned grouped into
    ``(num_groups, merge_size, merge_size, hidden_size)`` blocks.
    """

    main_input_name = 'pixel_values'
    config_class = NaViLVisionConfig
    _no_split_modules = ['NaViLVisionEncoderLayerAnyRes']

    def __init__(self, config: NaViLVisionConfig):
        super().__init__(config)
        self.config = config

        self.merge_size = int(1.0 / config.downsample_ratio)
        self.embeddings = NaViLVisionEmbeddingsAnyRes(config)
        self.encoder = NaViLVisionEncoderAnyRes(config)

    def get_input_embeddings(self):
        """Return the patch-embedding module."""
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images given either raw pixels or precomputed patch embeddings.

        Args:
            pixel_values: 4-D pixel tensor, embedded with ``self.embeddings``.
                Ignored when ``pixel_embeds`` is given.
            output_hidden_states: whether to also return every layer's hidden
                states; defaults to the config value.
            return_dict: whether to return a model-output object instead of a
                tuple; defaults to the config value.
            pixel_embeds: precomputed embeddings passed straight to the encoder.
            grid_thw: per-image ``(t, h, w)`` patch-grid sizes, forwarded to
                the encoder.

        Returns:
            A ``BaseModelOutputWithPooling`` (``pooler_output`` is always
            ``None``), or a tuple ``(last_hidden_state, *rest)`` when
            ``return_dict`` is false. ``last_hidden_state`` has shape
            ``(-1, merge_size, merge_size, hidden_size)``.

        Raises:
            ValueError: if neither input is given, or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif len(pixel_values.shape) == 4:
            hidden_states = self.embeddings(pixel_values)
        else:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            grid_thw=grid_thw
        )
        # Index instead of attribute access: with return_dict=False the encoder
        # returns a plain tuple, which has no `.last_hidden_state`. Integer
        # indexing works for both the tuple and the model-output object.
        last_hidden_state = encoder_outputs[0]

        # Group consecutive tokens into merge_size x merge_size blocks.
        last_hidden_state = last_hidden_state.reshape(
            -1, self.merge_size, self.merge_size, last_hidden_state.shape[-1]
        )

        if not return_dict:
            return (last_hidden_state, ) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
| |
|