Cosmos
Safetensors
NeMo
cosmos-embed1
nvidia
custom_code
Cosmos-Embed1-336p / modeling_utils.py
fferroni's picture
First commit
ecf8cbe
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Misc functions and modules for Cosmos-Embed1."""
import functools
from logging import getLogger
from typing import Callable, Optional, Protocol
import torch
import torch.distributed as dist
import torch.nn as nn
# Module-level logger. Named after the module's import path (``__name__``) rather
# than its filesystem path so it joins the standard dotted logger hierarchy and
# can be configured through its parent package's logger.
logger = getLogger(__name__)
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return the rank (GPU device index) of this worker.

    Args:
        group: process group passed through to ``torch.distributed.get_rank``.

    Returns:
        int: the rank reported by ``torch.distributed.get_rank(group)``, or 0 when
        ``torch.distributed`` is unavailable or has not been initialized.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank(group) if distributed_ready else 0
def barrier() -> None:
    """Synchronize all ranks; does nothing when torch.distributed is unavailable or uninitialized."""
    if not dist.is_available() or not dist.is_initialized():
        return
    dist.barrier()
def rank0_first(func: Callable) -> Callable:
    """Run ``func`` on rank 0 first, then on the other ranks.

    Rank 0 calls ``func`` and then enters ``barrier()``. Every other rank enters
    ``barrier()`` first and calls ``func`` afterwards. Each rank returns the
    result of its own call. Without an initialized process group ``barrier()``
    is a no-op, so ``func`` simply runs once.

    If ``func`` raises on rank 0, rank 0 still enters the barrier before the
    exception propagates. This releases the other ranks instead of leaving them
    waiting for a rank that never arrives.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if get_rank() == 0:
            try:
                return func(*args, **kwargs)
            finally:
                # Always reach the barrier, even when ``func`` raised, so the
                # ranks waiting below are released.
                barrier()
        barrier()
        return func(*args, **kwargs)

    return wrapper
def add_docstring(docstring: str):
    """Build a decorator that replaces the decorated callable's ``__doc__`` with ``docstring``.

    The decorator mutates the callable in place and returns that same object.
    """

    def _set_doc(target: Callable) -> Callable:
        target.__doc__ = docstring
        return target

    return _set_doc
# Shared docstring text attached (via ``add_docstring``) to the encoding
# modules' ``__init__`` methods.
INIT_DOCSTRING = """
Constructor for encoding module.
Args:
embed_dim: size of embedding vectors, e.g. x.shape[3].
max_len: maximum length of temporal sequence, e.g. x.shape[1].
"""
# Shared docstring text attached (via ``add_docstring``) to the encoding
# modules' ``forward`` methods.
FORWARD_DOCSTRING = """
Forward function.
Args:
x (`torch.Tensor`): rank 4 tensor to add spatio-temporal encodings to.
Returns:
`torch.Tensor` of rank 4.
"""
class EncodingProtocol(Protocol):
    """Structural interface for the temporal encoding modules in this file.

    Implementations are constructed from an embedding size and a maximum
    sequence length, and expose ``forward`` mapping a tensor to a tensor.
    """

    def __init__(self, embed_dim: int, max_len: int) -> None:
        """Create an encoding for ``embed_dim``-sized vectors over up to ``max_len`` steps."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the encoding applied."""
def interpolate_temp_pos_embed(temp_embed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Linearly resample a 2-D temporal table from ``temp_embed.shape[0]`` rows to ``num_frames`` rows.

    Args:
        temp_embed: 2-D tensor whose first axis is time.
        num_frames: number of rows in the output.

    Returns:
        Tensor of shape ``(num_frames, temp_embed.shape[1])``.
    """
    # ``interpolate`` resamples the last axis of a (batch, channels, length)
    # tensor, so move time to the last axis and add a singleton batch axis.
    channels_first = temp_embed.transpose(0, 1)[None]
    resampled = nn.functional.interpolate(
        channels_first,
        size=num_frames,
        mode="linear",
        align_corners=False,
    )
    # Drop the batch axis and move time back to the front.
    return resampled[0].transpose(0, 1)
class TemporalParameterEncoding(nn.Module, EncodingProtocol):
    """Learnable temporal encoding: one ``embed_dim``-sized vector per time step.

    The table has ``max_len`` rows. When the input's time axis (``x.shape[1]``)
    has a different length, the table is linearly resampled to that length. It
    is then broadcast-added to ``x`` across axes 0 and 2.
    """

    @add_docstring(INIT_DOCSTRING)
    def __init__(self, embed_dim: int, max_len: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        # Allocate a (max_len, embed_dim) table, then fill it in place with a
        # truncated normal (std 0.02).
        self.temp_embed = nn.Parameter(torch.zeros(self.max_len, self.embed_dim))
        nn.init.trunc_normal_(self.temp_embed, std=0.02)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tuple unpacking requires a rank-4 input; only the time length is used.
        _, t, _, _ = x.shape
        temp_embed = self.temp_embed
        if temp_embed.shape[0] != t:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(temp_embed, t)
        # (t, dim) -> (1, t, 1, dim) so it broadcasts over axes 0 and 2 of x.
        return x + temp_embed[None, :, None, :]
def create_neighbor_weight_matrix(num_tokens: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a ``(num_tokens, num_tokens)`` matrix whose ``(i, j)`` entry is ``2 ** -|i - j|``.

    Args:
        num_tokens: side length of the square matrix.
        device: device to allocate the matrix on.
        dtype: dtype of the matrix (also used for the intermediate index arithmetic).
    """
    positions = torch.arange(num_tokens, dtype=dtype, device=device)
    distance = (positions[None, :] - positions[:, None]).abs()
    # Kept as reciprocal of a power of two: 2**d overflows to inf for large d,
    # giving exactly 0, whereas 0.5**d would produce subnormals instead.
    return 1.0 / (2.0**distance)
def compute_t_adj(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Mix ``x`` across its third axis using ``weights``.

    Computes ``out[b, f, k, d] = sum_n x[b, f, n, d] * weights[n, k]``: axis 2 of
    the rank-4 ``x`` is contracted against the first axis of the 2-D ``weights``.
    """
    operands = (x, weights)
    return torch.einsum("bfnd,nk->bfkd", *operands)
def token_propagation(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Apply the neighboring-token propagation update to ``x``.

    The update has three steps:

    1. Build a ``num_tokens x num_tokens`` neighbor weight matrix on ``x``'s
       device and dtype.
    2. Mix ``x`` with that matrix along axis 2 (``compute_t_adj``).
    3. Return ``x + mixed - mixed.detach()``.

    The added term ``mixed - mixed.detach()`` contributes (up to floating-point
    rounding) nothing to the forward value, but lets gradients flow back
    through the mixing.
    """
    neighbor_weights = create_neighbor_weight_matrix(num_tokens, x.device, x.dtype)
    mixed = compute_t_adj(x, neighbor_weights)
    # Evaluation order matters numerically: (x + mixed) - mixed.detach().
    return x + mixed - mixed.detach()
class NeighboringTokenPropagationEncoding(TemporalParameterEncoding):
    """Temporal encoding with neighboring-token propagation applied in training mode.

    Inspired by Momentor (https://arxiv.org/abs/2402.11435). It reuses the
    learnable table of ``TemporalParameterEncoding``. While ``self.training`` is
    set, the broadcast table is passed through ``token_propagation`` over the
    ``x.shape[2]`` tokens before it is added to ``x``.
    """

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tuple unpacking requires a rank-4 input: t = time length, q = token count.
        _, t, q, _ = x.shape
        stored_len = self.temp_embed.shape[0]
        if t == stored_len:
            temp_embed = self.temp_embed
        else:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        # (t, dim) -> (1, t, 1, dim).
        temp_embed = temp_embed[None, :, None, :]
        if self.training:
            temp_embed = token_propagation(temp_embed, q)
        return x + temp_embed
class EncodingFactory(nn.Module):
    """Wrap one temporal encoding module selected by name.

    Supported ``encoding_type`` values:
      * ``"temporal_parameter"`` -> ``TemporalParameterEncoding``
      * ``"neighboring_token_propagation"`` -> ``NeighboringTokenPropagationEncoding``

    Raises:
        KeyError: if ``encoding_type`` is not one of the supported names. The
            message lists the accepted values.
    """

    def __init__(self, encoding_type: str, embed_dim: int, max_len: int) -> None:
        super().__init__()
        registry = {
            "temporal_parameter": TemporalParameterEncoding,
            "neighboring_token_propagation": NeighboringTokenPropagationEncoding,
        }
        try:
            fn = registry[encoding_type]
        except KeyError:
            # Same exception type as a bare dict lookup (callers may catch
            # KeyError), but with a message naming the valid choices.
            raise KeyError(
                f"Unknown encoding_type {encoding_type!r}; expected one of {sorted(registry)}."
            ) from None
        self.encoding = fn(embed_dim=embed_dim, max_len=max_len)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped encoding module.
        return self.encoding(x)