Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multimodal projector to connect vision encoder / tokenizer with the LLM.""" | |
from typing import Any, Optional | |
import torch | |
import torch.nn as nn | |
class DownSampleBlock(nn.Module):
    """Parameter-free 2x2 space-to-channel downsampling of a token sequence.

    The sequence of embeddings is interpreted as a square grid, every 2x2
    neighbourhood of tokens is concatenated along the channel axis, and the
    grid is flattened back into a sequence. The block holds no parameters.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample a sequence of embeddings laid out as a square grid.

        Args:
            x: Input tensor of shape ``(b, seq_len, c)``; ``seq_len`` must be
                a perfect square because the tokens are reshaped to a
                ``side x side`` grid.

        Returns:
            Tensor of shape ``(b, ceil(side / 2) ** 2, c * 4)``. For an even
            ``side`` this is ``(b, seq_len / 4, c * 4)``; for an odd ``side``
            the grid is zero-padded by one row and one column first.

        Raises:
            ValueError: If ``seq_len`` is not a perfect square.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        # Round to the nearest integer before verifying, so float error in the
        # square root can never turn a perfect square into "side - 1".
        side = int(seq_len ** 0.5 + 0.5)
        if side * side != seq_len:
            # Without this check the reshape below either fails with an opaque
            # error or silently folds the leftover tokens into the channel axis.
            raise ValueError(
                f"DownSampleBlock expects a square number of tokens, got seq_len={seq_len}"
            )
        grid = x.reshape(batch, side, side, -1)
        grid = self.flat_square(grid)
        return grid.reshape(batch, -1, grid.shape[-1])

    def flat_square(self, x: torch.Tensor) -> torch.Tensor:
        """Halve both spatial dimensions while quadrupling the channels.

        Args:
            x: Input grid of shape ``(b, h, w, c)``.

        Returns:
            Contiguous tensor of shape ``(b, ceil(h / 2), ceil(w / 2), c * 4)``.
            Position ``(i, j)`` holds the channels of input pixels
            ``(2i, 2j)``, ``(2i, 2j + 1)``, ``(2i + 1, 2j)``, ``(2i + 1, 2j + 1)``
            concatenated in that order. Odd ``h`` / ``w`` are zero-padded by one
            row / column at the end.
        """
        b, h, w, c = x.size()
        # If h or w is odd, append a row or a column of zeros. new_zeros
        # allocates with the dtype and device of x directly, avoiding a
        # host-side allocation followed by a device copy.
        if h % 2 == 1:
            x = torch.cat([x, x.new_zeros((b, 1, w, c))], dim=1)
            h += 1
        if w % 2 == 1:
            x = torch.cat([x, x.new_zeros((b, h, 1, c))], dim=2)
            w += 1
        # Merge horizontally adjacent pixel pairs into the channel axis.
        x = x.reshape(b, h, w // 2, c * 2)
        # Bring the height axis next to the channels: (b, w/2, h, 2c).
        x = x.permute(0, 2, 1, 3).contiguous()
        # Merge vertically adjacent rows. The layout here is (w/2, h, ...), so
        # the view must be (w/2, h/2); this keeps non-square grids unscrambled
        # and is identical to the previous behaviour whenever h == w.
        x = x.view(b, w // 2, h // 2, c * 4)
        # Restore (height, width) ordering: (b, h/2, w/2, 4c).
        return x.permute(0, 2, 1, 3).contiguous()
class MultimodalProjector(nn.Module):
    """Projects vision-side embeddings into the embedding space of the LLM.

    The wrapped module is selected by ``mm_projector_type``:

    * ``"identity"``: ``nn.Identity`` (``in_dim`` / ``out_dim`` are unused).
    * ``"linear"``: a single ``nn.Linear(in_dim, out_dim)``.
    * ``"mlp"``: ``Linear -> GELU -> Linear``.
    * ``"mlp_downsample"``: ``DownSampleBlock -> LayerNorm(in_dim * 4)``
      followed by the same MLP fed with ``in_dim * 4`` features.
    """

    def __init__(
        self,
        mm_projector_type: str,
        in_dim: int,
        out_dim: Optional[int] = None,
        **kwargs: Any,
    ):
        """Build the projector.

        Args:
            mm_projector_type: One of ``"identity"``, ``"linear"``, ``"mlp"``
                or ``"mlp_downsample"``.
            in_dim: Feature size of the incoming embeddings.
            out_dim: Feature size of the projected embeddings; falls back to
                ``in_dim`` when ``None``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``mm_projector_type`` is not a known type.
        """
        super().__init__()
        target_dim = in_dim if out_dim is None else out_dim

        if mm_projector_type == "identity":
            projector = nn.Identity()
        elif mm_projector_type == "linear":
            projector = nn.Linear(in_dim, target_dim)
        elif mm_projector_type in ("mlp", "mlp_downsample"):
            stages = []
            mlp_in = in_dim
            if mm_projector_type == "mlp_downsample":
                # The downsample block concatenates 2x2 neighbourhoods, so the
                # layers after it see four times as many input features.
                mlp_in = in_dim * 4
                stages.append(DownSampleBlock())
                stages.append(nn.LayerNorm(mlp_in))
            stages.append(nn.Linear(mlp_in, target_dim))
            stages.append(nn.GELU())
            stages.append(nn.Linear(target_dim, target_dim))
            projector = nn.Sequential(*stages)
        else:
            raise ValueError(f"Unknown projector type: {mm_projector_type}")

        self.projector = projector

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the configured projector and return the result."""
        projected = self.projector(x)
        return projected