# Cosmos
# Safetensors
# NeMo
# cosmos-embed1
# nvidia
# custom_code
# Cosmos-Embed1-336p / preprocessing_embed1.py
# fferroni's picture
# First commit
# ecf8cbe
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Cosmos-Embed1
"""
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torchvision
from transformers import AutoProcessor, BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType
from .configuration_embed1 import CosmosEmbed1Config
class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Processor for Cosmos-Embed1: wraps a BERT tokenizer for text and applies
    validation, optional resizing and 1/255 scaling to video tensors.

    Args:
        tokenizer ([`BertTokenizerFast`], *optional*):
            Tokenizer invoked whenever `text` is passed to `__call__`.
        resolution (`int` or `Tuple[int, int]`, *optional*, defaults to 336):
            Default target (height, width) for videos. An `int` is used for
            both dimensions.
        num_video_frames (`int`, *optional*, defaults to 8):
            Default number of frames each video is required to have.
        max_txt_len (`int`, *optional*, defaults to 128):
            Default length that tokenized text is padded/truncated to.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 336,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        """Store the tokenizer (via the parent class) and the default processing options."""
        super().__init__(tokenizer, **kwargs)
        self.max_txt_len = max_txt_len
        self.num_video_frames = num_video_frames
        self.resolution = resolution

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Union[int, Tuple[int, int]] = None,
        num_video_frames: int = None,
        max_txt_len: int = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize text and/or preprocess videos into model inputs.

        Args:
            text: A string or list of strings to tokenize. Skipped when `None`.
            videos: Numpy array or torch tensor of shape BTCHW with values in
                [0-255]. Skipped when `None`.
            return_tensors: Tensor type forwarded to `BatchFeature`.
            resolution: Per-call override of the target (height, width).
            num_video_frames: Per-call override of the required frame count.
            max_txt_len: Per-call override of the padded text length.
            **kwargs: Extra keyword arguments forwarded to the tokenizer.

        Returns:
            A `BatchFeature` holding `input_ids` and `attention_mask` (when
            `text` is given) and/or `videos` (when `videos` is given).

        Raises:
            ValueError: If `videos` is not a 5-D array/tensor, does not have 3
                channels, has the wrong number of frames, or if neither `text`
                nor `videos` is supplied.
        """
        features = {}

        if text is not None:
            # Fall back to the instance default only when no override is given.
            if max_txt_len is None:
                max_txt_len = self.max_txt_len
            encoding = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_txt_len,
                **kwargs,
            )
            features["input_ids"] = encoding.input_ids
            features["attention_mask"] = encoding.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            is_5d_tensor = isinstance(videos, torch.Tensor) and videos.ndim == 5
            if not is_5d_tensor:
                raise ValueError("Processor expects a numpy or torch tensor of shape BTCHW from [0-255].")

            if resolution is None:
                resolution = self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)

            frame_count, channel_count, height, width = videos.shape[1:]
            if channel_count != 3:
                raise ValueError(
                    f"Expected tensor of shape BTCHW with RGB channels, got channel size {channel_count}."
                )

            if num_video_frames is None:
                num_video_frames = self.num_video_frames
            if frame_count != num_video_frames:
                raise ValueError(
                    f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {frame_count}."
                )

            # Only resample when the spatial size differs from the target.
            if not (height == resolution[0] and width == resolution[1]):
                videos = resize_video(videos, resolution)

            if videos.dtype == torch.uint8:
                videos = videos.float()
            features["videos"] = videos / 255.0

        if len(features) == 0:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")
        return BatchFeature(features, tensor_type=return_tensors)
def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    The batch and time axes are merged so every frame is resized in a single
    call using antialiased bilinear interpolation, then the original
    batch/time layout is restored.

    Args:
        video (torch.Tensor): (B, T, C, H, W) uint8 or float32. Does not need
            to be contiguous.
        size (tuple): target (H', W') size.

    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W'). If the input
        already has the target height and width it is returned as-is.

    Raises:
        ValueError: if ``video`` does not have exactly 5 dimensions.
    """
    if video.ndim != 5:
        raise ValueError(f"Expected a video tensor of shape BTCHW, got {video.ndim} dimensions.")

    target_h, target_w = size
    batch, frames, channels, height, width = video.shape

    # Already at the target size: nothing to resample.
    if (height, width) == (target_h, target_w):
        return video

    # `reshape` rather than `view`: `view` raises on inputs whose batch/time
    # strides cannot be merged without a copy (e.g. a temporally sliced batch).
    flat = video.reshape(batch * frames, channels, height, width)
    resize = torchvision.transforms.Resize(
        (target_h, target_w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    flat = resize(flat)
    new_h, new_w = flat.shape[-2:]
    return flat.reshape(batch, frames, channels, new_h, new_w)
# Register CosmosEmbed1Processor with AutoProcessor, keyed by CosmosEmbed1Config.
# This runs as a side effect of importing this module.
AutoProcessor.register(CosmosEmbed1Config, CosmosEmbed1Processor)
# Only the processor class is exported by `from ... import *`.
__all__ = ["CosmosEmbed1Processor"]