# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, | |
designed for processing image inputs in vision-language models. | |
This module follows Mistral's vision encoder implementation (for their Pistral-12B VLM): | |
https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py | |
""" | |

from functools import partial
from typing import Any, Callable, Mapping, Optional, Tuple

import torch
import torch.nn as nn

from cosmos_predict1.autoregressive.modules.normalization import create_norm
from cosmos_predict1.autoregressive.networks.transformer import TransformerBlock
from cosmos_predict1.utils import log


def get_vit_config(model_name: str) -> Mapping[str, Any]:
    """
    Get the ViT configuration for a given model name.
    """
    if model_name == "pixtral-12b-vit":
        # The 400M ViT of Pixtral 12B VLM
        return dict(
            dim=1024,
            num_channels=3,
            image_size=1024,
            patch_size=16,
            rope_theta=10000,
            ffn_hidden_size=4096,
            n_layers=24,
            n_heads=16,
            n_kv_heads=16,
            norm_type="rmsnorm",
            norm_eps=1e-5,
            image_token_id=10,
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")
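
# Example (illustrative usage, not part of the original source): fetch the Pixtral-style ViT
# hyperparameters as a plain dict; individual fields can be overridden before passing the
# dict to VisionTransformer.build().
#
#     vit_config = get_vit_config("pixtral-12b-vit")
#     assert vit_config["dim"] == 1024 and vit_config["patch_size"] == 16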


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Precompute 2D complex tensor for rotary position embedding.

    This function generates a 2D complex tensor used for rotary position embeddings,
    which helps the model understand spatial relationships in the input image.

    Args:
        dim (int): Dimension of the model (typically the hidden size divided by number of heads).
        height (int): Height of the image in patches.
        width (int): Width of the image in patches.
        theta (float): Base value for the angle calculation, controls the frequency range.

    Returns:
        torch.Tensor: 2D complex tensor of shape (height, width, dim // 2).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
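
# Example sketch (hypothetical numbers, not from the source): with head_dim=64 and a 4x4
# patch grid, row and column frequencies are interleaved, so the result is a complex tensor
# of shape (height, width, dim // 2).
#
#     freqs = precompute_freqs_cis_2d(dim=64, height=4, width=4, theta=10000.0)
#     assert freqs.shape == (4, 4, 32) and freqs.is_complex()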


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting with input tensor.

    This function ensures that the frequency tensor can be properly broadcast
    with the input tensor during the rotary embedding process.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d.
        x (torch.Tensor): Input tensor to be embedded.

    Returns:
        torch.Tensor: Reshaped frequency tensor ready for broadcasting.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}"
    assert freqs_cis.shape == (
        x.shape[1],
        x.shape[-1],
    ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}"
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
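
# Shape sketch (illustrative, assumed shapes): for a complex query tensor of shape
# (batch, seq, n_heads, head_dim // 2) and freqs_cis of shape (seq, head_dim // 2),
# the returned view has shape (1, seq, 1, head_dim // 2), so it broadcasts over the
# batch and head dimensions.
#
#     fc = torch.polar(torch.ones(256, 32), torch.zeros(256, 32))  # identity rotation
#     xq_ = torch.randn(2, 256, 16, 32, dtype=torch.complex64)
#     assert reshape_for_broadcast(fc, xq_).shape == (1, 256, 1, 32)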


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    *args,
    freqs_cis: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings to input tensors.

    This function applies the rotary positional embeddings to the query and key tensors,
    which helps the model understand spatial relationships in the input.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        freqs_cis (torch.Tensor): Precomputed frequencies from precompute_freqs_cis_2d.
        *args: Variable length argument list (unused).
        **kwargs: Arbitrary keyword arguments (unused).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
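
# Example sketch (hypothetical shapes): queries/keys laid out as (batch, seq, n_heads,
# head_dim); the flattened patch grid below assumes a 16x16 grid (256 patches) and
# head_dim=64, matching the Pixtral-style config above.
#
#     freqs_cis = precompute_freqs_cis_2d(dim=64, height=16, width=16, theta=10000.0)
#     freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])  # (256, 32)
#     xq = torch.randn(2, 256, 16, 64)
#     xk = torch.randn(2, 256, 16, 64)
#     xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
#     assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape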


class VisionTransformer(nn.Module):
    """
    Vision Transformer model for image processing.

    This class implements a Vision Transformer that processes images using a patch-based approach
    and applies transformer layers with rotary position embeddings.

    Args:
        dim (int): Dimension of the model (hidden size).
        num_channels (int): Number of input image channels (e.g., 3 for RGB).
        patch_size (int): Size of each image patch (e.g., 16x16 pixels).
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        n_kv_heads (Optional[int]): Number of key/value heads (defaults to n_heads if None).
        ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks.
        norm_type (str): Type of normalization to use (e.g., "rmsnorm").
        norm_eps (float): Epsilon value for normalization layers.
        image_size (int): Size of the input image (assumed square).
        rope_theta (float): Base value for rotary position embedding calculation.
        image_token_id (Optional[int]): Token ID for the image token (if present).
        tensor_model_parallel_size (int): Tensor model parallel size used by the transformer blocks.
    """

    def __init__(
        self,
        dim: int = 1024,
        num_channels: int = 3,
        patch_size: int = 16,
        n_layers: int = 24,
        n_heads: int = 16,
        n_kv_heads: Optional[int] = None,
        ffn_hidden_size: int = 4096,
        norm_type: str = "rmsnorm",
        norm_eps: float = 1e-5,
        image_size: int = 1024,
        rope_theta: float = 1000000.0,
        image_token_id: Optional[int] = None,
        tensor_model_parallel_size: int = 1,
    ):
        super().__init__()
        self.patch_conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps)
        if n_kv_heads is None:
            n_kv_heads = n_heads
        layer_args = dict(
            n_layers=n_layers,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            dim=dim,
            use_qk_normalization=False,
            max_seq_len=None,
            max_batch_size=None,
            ffn_hidden_size=ffn_hidden_size,
            norm_type=norm_type,
            norm_eps=norm_eps,
            causal_mask=False,  # Full attention in ViT
            head_dim=None,
            insert_cross_attn=False,
            tensor_model_parallel_size=tensor_model_parallel_size,
            attn_type="full",
        )
        self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args)
        head_dim = dim // n_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self.dim = dim
        self.n_heads = n_heads
        self.max_patches_per_side = image_size // patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.rope_theta = rope_theta
        self._freqs_cis: Optional[torch.Tensor] = None
        self.image_token_id = image_token_id
        num_params = self.get_num_params()
        log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M")

    @classmethod
    def build(
        cls,
        config: Mapping[str, Any],
    ) -> "VisionTransformer":
        """
        Create a Vision Transformer from a configuration dictionary.

        This class method creates a Vision Transformer from a configuration dictionary,
        which is typically loaded from a JSON file or other configuration source.

        Args:
            config (Mapping[str, Any]): Configuration dictionary for the Vision Transformer.

        Returns:
            VisionTransformer: Vision Transformer model instance.
        """
        necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"]
        missing_keys = [k for k in necessary_keys if k not in config]
        assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}"
        return cls(
            **config,
        )
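
    # Example sketch (illustrative, assumes the config helper defined above): build the
    # Pixtral-style ViT directly from its default configuration.
    #
    #     vit = VisionTransformer.build(get_vit_config("pixtral-12b-vit"))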

    def expand_in_channels(self, new_in_channels: int):
        """
        Expand the input channels of the patch convolution layer.

        This is useful when the input is non-standard, e.g. a 4-channel image with the last channel
        as the alpha channel. Note that you should only call this method after the weights are loaded.
        """
        assert (
            new_in_channels > self.patch_conv.in_channels
        ), "The new number of input channels must be greater than the original number of channels."
        log.debug(
            f"Vision encoder in_channels is {self.patch_conv.in_channels}, but {new_in_channels} was requested. "
            f"Expanding to {new_in_channels} channels; the extra {new_in_channels - self.patch_conv.in_channels} "
            "channels are initialized to zeros."
        )
        new_conv = nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=self.patch_conv.out_channels,
            kernel_size=self.patch_conv.kernel_size,
            stride=self.patch_conv.stride,
            bias=False,
        )
        new_conv.weight.data[:, : self.patch_conv.in_channels].copy_(self.patch_conv.weight.data)
        new_conv.weight.data[
            :, self.patch_conv.in_channels :
        ].zero_()  # Zero-init so the new channels initially have no effect on the output
        self.patch_conv = new_conv
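
    # Example sketch (hypothetical usage): after loading pretrained weights, accept RGBA input
    # by widening the patch embedding to 4 channels; the alpha channel starts with zero weights.
    #
    #     vit.expand_in_channels(4)
    #     rgba = torch.randn(1, 4, 256, 256)
    #     features = vit(rgba)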

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """
        Get or compute the frequency tensor for rotary position embedding.

        This property lazily initializes and caches the frequency tensor used for
        rotary position embeddings, ensuring it's on the correct device.

        Returns:
            torch.Tensor: The frequency tensor for rotary position embeddings.
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.dim // self.n_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.rope_theta,
            )
        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)
        return self._freqs_cis

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the Vision Transformer.

        This method processes the input image through the Vision Transformer,
        including patch embedding, position embedding, and transformer layers.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size,
                C is number of channels, and H, W are height and width.

        Returns:
            torch.Tensor: Output features of shape (B, N, D), where N is the number of patches
                and D is the embedding dimension.
        """
        patch_embeds = self.patch_conv(x)  # (B, D, Hp, Wp)
        _, _, Hp, Wp = patch_embeds.shape  # Patch embeds dim
        patch_embeds = patch_embeds.flatten(2)  # (B, D, Hp*Wp)
        patch_embeds = patch_embeds.transpose(1, 2)  # (B, Hp*Wp, D)
        patch_embeds = self.ln_pre(patch_embeds)  # (B, Hp*Wp, D)
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(Hp),
                torch.arange(Wp),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
        rope = partial(apply_rotary_emb, freqs_cis=freqs_cis)
        out = self.transformer(patch_embeds, rope=rope)
        return out
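
    # Shape walkthrough (illustrative, assuming patch_size=16): a (1, 3, 256, 256) image yields
    # a 16x16 patch grid, so Hp = Wp = 16, N = Hp * Wp = 256, and the output is (1, 256, dim).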

    def get_num_params(
        self,
    ) -> int:
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params


class VisionTransformerBlocks(nn.Module):
    """
    Vision Transformer Blocks.

    This class implements a stack of Transformer blocks used in the Vision Transformer.

    Args:
        n_layers (int): Number of transformer layers.
        args (Mapping[str, Any]): Arguments for each transformer block, including dimensions,
            attention heads, and normalization settings.
    """

    def __init__(
        self,
        n_layers: int,
        args: Mapping[str, Any],
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for layer_id in range(n_layers):
            self.layers.append(
                TransformerBlock(
                    layer_id=layer_id,
                    args=args,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        rope: Callable,
    ) -> torch.Tensor:
        """
        Forward pass through the Vision Transformer Blocks.

        This method applies a series of Transformer blocks to the input tensor,
        using the provided rotary position embedding function.

        Args:
            x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch size,
                N is the number of patches, and D is the embedding dimension.
            rope (Callable): Rotary position embedding function to be applied in each layer.

        Returns:
            torch.Tensor: Output tensor after passing through all transformer layers,
                with the same shape as the input.
        """
        for layer in self.layers:
            x = layer(x, input_pos=None, mask=None, rope=rope)
        return x
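

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original source). It assumes the
    # cosmos_predict1 TransformerBlock accepts the layer_args dict built in __init__ and that the
    # model fits in CPU memory; any image size that is a multiple of the 16-pixel patch works.
    vit = VisionTransformer.build(get_vit_config("pixtral-12b-vit"))
    vit.eval()
    dummy_image = torch.randn(1, 3, 256, 256)  # (B, C, H, W) -> 16x16 = 256 patches
    with torch.no_grad():
        features = vit(dummy_image)
    print(features.shape)  # expected: torch.Size([1, 256, 1024])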