|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Output definitions for Cosmos-Embed1.""" |
|
|
|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
import torch |
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, ModelOutput |
|
|
|
|
|
@dataclass
class TextEmbedderOutput(ModelOutput):
    """Output of the text branch's `get_text_embeddings` function.

    Every field is optional and defaults to `None`.

    Attrs:
        text_proj (`torch.FloatTensor` of shape `(batch_size, num_visual_embs, embed_dim)` or `(batch_size, embed_dim)`):
            text (video-aligned) projected embeddings from text branch.
        text_embs (`torch.FloatTensor` of shape `(batch_size, ...)`):
            text tokens from text branch.
        text_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Useful text branch intermediate outputs like hidden states, past key values, attentions etc.
    """

    # Field order is part of the interface: it determines the positional
    # (tuple-style) layout of the output, so do not reorder.
    text_proj: Optional[torch.FloatTensor] = None
    text_embs: Optional[torch.FloatTensor] = None
    text_query_output: Optional[CausalLMOutputWithCrossAttentions] = None
|
|
|
|
|
@dataclass
class VideoEmbedderOutput(ModelOutput):
    """Output of the visual branch's `get_video_embeddings` function.

    Every field is optional and defaults to `None`.

    Attrs:
        visual_proj (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
            visual (text-aligned) projected embeddings from visual branch.
        visual_embs (`torch.FloatTensor` of shape `(batch_size, num_frames, height, width, encoder_dim)`):
            per-frame dense visual embeddings from visual encoder.
        visual_cls_tokens (`torch.FloatTensor` of shape `(batch_size, qformer_dim)`):
            visual pooled tokens from visual branch prior to projection and normalization.
        frame_cls_tokens (`torch.FloatTensor` of shape `(batch_size, num_frames, encoder_dim)`):
            per-frame cls tokens from visual encoder.
        visual_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Useful visual branch intermediate outputs like hidden states, past key values, attentions etc.
    """

    # Field order is part of the interface: it determines the positional
    # (tuple-style) layout of the output, so do not reorder.
    visual_proj: Optional[torch.FloatTensor] = None
    visual_embs: Optional[torch.FloatTensor] = None
    visual_cls_tokens: Optional[torch.FloatTensor] = None
    frame_cls_tokens: Optional[torch.FloatTensor] = None
    visual_query_output: Optional[CausalLMOutputWithCrossAttentions] = None
|
|
|
|
|
@dataclass
class TextVideoEmbedderOutput(VideoEmbedderOutput, TextEmbedderOutput):
    """Merged class of `VideoEmbedderOutput` and `TextEmbedderOutput`.

    Declares no fields of its own; it only combines those of its two parents.
    Dataclass inheritance collects base-class fields in reverse MRO order, so
    the `TextEmbedderOutput` fields come first, followed by the
    `VideoEmbedderOutput` fields.
    """
|
|