Cosmos-Embed1-336p / modeling_embed1.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cosmos-Embed1 text+video embedder."""
import math
from copy import deepcopy
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from transformers import AutoModel, PreTrainedModel
from .configuration_embed1 import CosmosEmbed1Config
from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput
from .modeling_qformer import BertLMHeadModel, load_qformer
from .modeling_utils import EncodingFactory, rank0_first
from .modeling_vit import EvaViTG
class CosmosEmbed1(PreTrainedModel):
config_class = CosmosEmbed1Config
def __init__(self, config: CosmosEmbed1Config) -> None:
"""Cosmos-Embed1 video embedder constructor.
Args:
config (CosmosEmbed1Config): Model configuration.
"""
super().__init__(config)
self.embed_dim = config.embed_dim
self.num_query_tokens = config.num_query_tokens
self.num_video_frames = config.num_video_frames
self.temporal_encoding_type = config.temporal_encoding_type
self.resolution = config.resolution
self.vocab_size = config.vocab_size
self.transformer_engine = config.transformer_engine
self.use_fp8 = config.use_fp8
# visual encoder initialization
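        # ImageNet RGB mean/std buffers used to normalize input frames.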
self.register_buffer(
"normalization_mean",
torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1),
persistent=False,
)
self.register_buffer(
"normalization_std",
torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1),
persistent=False,
)
self.visual_encoder = EvaViTG(
img_size=self.resolution,
transformer_engine=self.transformer_engine,
use_fp8=self.use_fp8,
)
self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)
# qformer initialization
self.qformer, self.query_tokens = self._init_qformer(
num_query_tokens=self.num_query_tokens,
encoder_width=self.visual_encoder.embed_dim,
vocab_size=self.vocab_size,
)
        # initialize the duplicated *_query parameters from their base Q-Former counterparts
state_dict = self.qformer.state_dict()
for name, param in self.qformer.named_parameters():
if "_query" in name:
key_orig = name.replace("_query", "")
param.data.copy_(state_dict[key_orig])
# temporal encoding
self.temporal_encoding = EncodingFactory(
self.temporal_encoding_type,
embed_dim=self.visual_encoder.embed_dim,
max_len=self.num_video_frames,
)
# output projections
self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)
# initialize logit scale/bias like SigLIP (as per Table 4 in https://arxiv.org/pdf/2303.15343)
self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
self.logit_bias = nn.Parameter(torch.tensor(-10.0))
@property
def hidden_dim(self) -> int:
return self.visual_encoder.embed_dim
@torch.jit.ignore
def no_weight_decay(self) -> set:
ret = {"logit_scale", "logit_bias"}
return ret
def forward(
self,
videos: torch.FloatTensor,
input_ids: torch.LongTensor,
attention_mask: torch.FloatTensor,
) -> TextVideoEmbedderOutput:
"""Forward function for `ComosEmbed1`.
Args:
videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
batched videos with fixed number of RGB frames.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using [`AutoTokenizer`] or [`CosmosEmbed1Tokenizer`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**.
- 0 for tokens that are **masked**.
"""
video_output = self.get_video_embeddings(videos)
text_output = self.get_text_embeddings(input_ids, attention_mask)
return TextVideoEmbedderOutput(**video_output, **text_output)
def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
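        """Compute video embeddings from batched RGB frames.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                batched videos with a fixed number of RGB frames; normalized internally with ImageNet statistics.

        Returns:
            VideoEmbedderOutput with the L2-normalized `visual_proj` embedding, the per-patch `visual_embs`,
            frame-level CLS tokens, and the raw Q-Former query output.
        """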
videos = (videos - self.normalization_mean) / self.normalization_std
batch_size, num_frames, _, H, W = videos.shape
frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")
# process video frames through ViT
visual_embs = self.visual_encoder(frame_batch)
visual_embs = self.ln_vision(visual_embs)
visual_embs = rearrange(
visual_embs,
"(b t) k d -> b t k d",
b=batch_size,
t=num_frames,
k=visual_embs.size(1),
d=visual_embs.size(2),
)
# add temporal encoding
visual_embs = self.temporal_encoding(visual_embs)
# Q-Former cross-attention
encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
encoder_attention_mask = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long).to(videos.device)
query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
visual_query_output = self.qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
)
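        # mean-pool the query-token outputs into a single video-level embedding,
        # then project to the shared embedding space and L2-normalize.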
visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
visual_proj = self.vision_proj(visual_cls_tokens)
visual_proj = F.normalize(visual_proj, dim=-1)
        # reshape visual embs to (B, T, h, w, D) over the patch grid to conform to the expected output.
        # separate out the frame-level CLS tokens if necessary.
frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
h = H // self.visual_encoder.patch_size
w = W // self.visual_encoder.patch_size
visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)
return VideoEmbedderOutput(
visual_proj=visual_proj,
visual_embs=visual_embs,
visual_query_output=visual_query_output,
visual_cls_tokens=visual_cls_tokens,
frame_cls_tokens=frame_cls_tokens,
)
def get_text_embeddings(
self,
input_ids: torch.LongTensor,
attention_mask: torch.FloatTensor,
) -> TextEmbedderOutput:
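        """Compute text embeddings from tokenized captions.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.

        Returns:
            TextEmbedderOutput with the L2-normalized `text_proj` embedding and the raw Q-Former text output.
        """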
text_query_output = self.qformer.bert(
input_ids=input_ids,
attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
return_dict=True,
)
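        # take the [CLS] token, project it to the shared embedding space, and L2-normalize.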
text_proj = text_query_output.last_hidden_state[:, 0, :]
text_proj = self.text_proj(text_proj)
text_proj = F.normalize(text_proj, dim=-1)
return TextEmbedderOutput(
text_proj=text_proj,
text_embs=text_query_output.last_hidden_state,
text_query_output=text_query_output,
)
@classmethod
@rank0_first
def _init_qformer(
        cls: type["CosmosEmbed1"],
num_query_tokens: int,
encoder_width: int,
vocab_size: int,
hidden_size: int = 768,
) -> tuple[BertLMHeadModel, nn.Parameter]:
"""Convenience function for initializing QFormer module."""
qformer = load_qformer(
num_query_tokens=num_query_tokens,
encoder_width=encoder_width,
hidden_size=hidden_size,
vocab_size=vocab_size,
)
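        # learnable query tokens, initialized like BERT weights (normal, std=0.02).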
query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
query_tokens.data.normal_(mean=0.0, std=0.02)
return qformer, query_tokens
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
# Get config from kwargs or load from pretrained path
config = kwargs.get("config", None)
if config is None:
config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)
if config.transformer_engine:
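            # Load the checkpoint into a standard (non-TE) model first, then copy its weights into a
            # freshly built TransformerEngine model; strict=False tolerates TE-specific parameter differences.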
config_no_te = deepcopy(config)
config_no_te.transformer_engine = False
config_no_te.use_fp8 = False # Also disable FP8 for the base model
            # Override 'config' in kwargs so the base model is built without TransformerEngine.
kwargs_no_te = deepcopy(kwargs)
kwargs_no_te["config"] = config_no_te
# Load standard (non-TE) model & weights
base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
base_state_dict = base_model.state_dict()
# Now build the TE version of the model
model_with_te = cls(config=config)
# Load weights from non-TE model
missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)
# Optional debug log
if missing:
print(f"[TransformerEngine] Missing keys: {missing}")
if unexpected:
print(f"[TransformerEngine] Unexpected keys: {unexpected}")
return model_with_te
else:
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
AutoModel.register(CosmosEmbed1Config, CosmosEmbed1)
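
# Example usage (a minimal sketch, not part of the model implementation). It assumes a
# published checkpoint id such as "nvidia/Cosmos-Embed1-336p" and uses random tensors in
# place of real tokenized text and preprocessed video frames; shapes follow the conventions
# documented in `forward` above (frames are RGB, expected roughly in the [0, 1] range).
#
#     import torch
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("nvidia/Cosmos-Embed1-336p", trust_remote_code=True)
#     model.eval()
#
#     videos = torch.rand(1, model.num_video_frames, 3, model.resolution, model.resolution)
#     input_ids = torch.randint(0, model.vocab_size, (1, 32))
#     attention_mask = torch.ones_like(input_ids)
#
#     with torch.no_grad():
#         out = model(videos=videos, input_ids=input_ids, attention_mask=attention_mask)
#
#     # both projections are L2-normalized, so a dot product gives cosine similarity
#     similarity = (out.visual_proj * out.text_proj).sum(dim=-1)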