# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Output definitions for Cosmos-Embed1."""

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, ModelOutput


@dataclass
class TextEmbedderOutput(ModelOutput):
    """Output of the text embedder branch's `get_text_embeddings` function.

    Every field is optional and defaults to `None`.

    Attributes:
        text_proj (`torch.FloatTensor` of shape `(batch_size, num_visual_embs, embed_dim)` or `(batch_size, embed_dim)`):
            Text (video-aligned) projected embeddings from the text branch.
        text_embs (`torch.FloatTensor` of shape `(batch_size, ...)`):
            Text tokens from the text branch.
        text_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Intermediate outputs of the text branch, such as hidden states,
            past key values and attentions.
    """

    text_proj: Optional[torch.FloatTensor] = None
    text_embs: Optional[torch.FloatTensor] = None
    text_query_output: Optional[CausalLMOutputWithCrossAttentions] = None


@dataclass
class VideoEmbedderOutput(ModelOutput):
    """Output of the video embedder branch's `get_video_embeddings` function.

    Every field is optional and defaults to `None`.

    Attributes:
        visual_proj (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
            Visual (text-aligned) projected embeddings from the visual branch.
        visual_embs (`torch.FloatTensor` of shape `(batch_size, num_frames, height, width, encoder_dim)`):
            Per-frame dense visual embeddings from the visual encoder.
        visual_cls_tokens (`torch.FloatTensor` of shape `(batch_size, qformer_dim)`):
            Visual pooled tokens from the visual branch, prior to projection
            and normalization.
        frame_cls_tokens (`torch.FloatTensor` of shape `(batch_size, num_frames, encoder_dim)`):
            Per-frame CLS tokens from the visual encoder.
        visual_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Intermediate outputs of the visual branch, such as hidden states,
            past key values and attentions.
    """

    visual_proj: Optional[torch.FloatTensor] = None
    visual_embs: Optional[torch.FloatTensor] = None
    visual_cls_tokens: Optional[torch.FloatTensor] = None
    frame_cls_tokens: Optional[torch.FloatTensor] = None
    visual_query_output: Optional[CausalLMOutputWithCrossAttentions] = None


@dataclass
class TextVideoEmbedderOutput(VideoEmbedderOutput, TextEmbedderOutput):
    """Merged class of `VideoEmbedderOutput` and `TextEmbedderOutput`.

    Carries every field of both parents, each defaulting to `None`.

    Field order: `dataclass` collects inherited fields in reverse MRO order.
    The MRO here is (self, VideoEmbedderOutput, TextEmbedderOutput, ...), so
    the text fields (`text_proj`, `text_embs`, `text_query_output`) come
    first. The visual fields (`visual_proj`, `visual_embs`,
    `visual_cls_tokens`, `frame_cls_tokens`, `visual_query_output`) follow.
    """