"""Cosmos-Embed1 text+video embedder.""" |
|
|
|
import math |
|
from copy import deepcopy |
|
|
|
import torch |
|
from einops import rearrange |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from transformers import AutoModel, PreTrainedModel |
|
|
|
from .configuration_embed1 import CosmosEmbed1Config |
|
from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput |
|
from .modeling_qformer import BertLMHeadModel, load_qformer |
|
from .modeling_utils import EncodingFactory, rank0_first |
|
from .modeling_vit import EvaViTG |


class CosmosEmbed1(PreTrainedModel):
    config_class = CosmosEmbed1Config

    def __init__(self, config: CosmosEmbed1Config) -> None:
        """Cosmos-Embed1 text+video embedder constructor.

        Args:
            config (CosmosEmbed1Config): Model configuration.
        """
        super().__init__(config)

        self.embed_dim = config.embed_dim
        self.num_query_tokens = config.num_query_tokens
        self.num_video_frames = config.num_video_frames
        self.temporal_encoding_type = config.temporal_encoding_type
        self.resolution = config.resolution
        self.vocab_size = config.vocab_size
        self.transformer_engine = config.transformer_engine
        self.use_fp8 = config.use_fp8

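        # ImageNet channel statistics, kept as non-persistent buffers so that input videos
        # (expected as floats in [0, 1]) can be normalized on the fly in get_video_embeddings.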
        self.register_buffer(
            "normalization_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "normalization_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
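        # Per-frame visual backbone: an EVA ViT-g encoder followed by a LayerNorm over its tokens.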
        self.visual_encoder = EvaViTG(
            img_size=self.resolution,
            transformer_engine=self.transformer_engine,
            use_fp8=self.use_fp8,
        )
        self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)

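        # Q-Former: a BERT-style module whose learnable query tokens cross-attend to the
        # (temporally encoded) ViT tokens and summarize a video into a fixed number of queries.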
        self.qformer, self.query_tokens = self._init_qformer(
            num_query_tokens=self.num_query_tokens,
            encoder_width=self.visual_encoder.embed_dim,
            vocab_size=self.vocab_size,
        )

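        # Q-Former parameters whose names contain "_query" are query-stream copies of the
        # shared BERT weights; initialize them from their text-stream counterparts.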
        state_dict = self.qformer.state_dict()
        for name, param in self.qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

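        # Temporal position encoding applied across the frame axis before tokens from all
        # frames are flattened into a single Q-Former context.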
        self.temporal_encoding = EncodingFactory(
            self.temporal_encoding_type,
            embed_dim=self.visual_encoder.embed_dim,
            max_len=self.num_video_frames,
        )

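        # Projection heads: map pooled Q-Former outputs into the shared text/video embedding
        # space, plus a two-way head for video-text matching (ITM).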
        self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)

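        # Learnable temperature and bias for the contrastive logits; the log(10) / -10 values
        # match the initialization used by sigmoid-contrastive (SigLIP-style) objectives.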
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.logit_bias = nn.Parameter(torch.tensor(-10.0))

    @property
    def hidden_dim(self) -> int:
        return self.visual_encoder.embed_dim

    @torch.jit.ignore
    def no_weight_decay(self) -> set:
        ret = {"logit_scale", "logit_bias"}
        return ret

    def forward(
        self,
        videos: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextVideoEmbedderOutput:
        """Forward function for `CosmosEmbed1`.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                Batched videos with a fixed number of RGB frames.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using [`AutoTokenizer`] or [`CosmosEmbed1Tokenizer`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**.
                - 0 for tokens that are **masked**.

        Returns:
            `TextVideoEmbedderOutput` combining the video and text embedding outputs.
        """
        video_output = self.get_video_embeddings(videos)
        text_output = self.get_text_embeddings(input_ids, attention_mask)
        return TextVideoEmbedderOutput(**video_output, **text_output)

    def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
        """Compute pooled and dense video embeddings for a batch of RGB videos (normalization is applied internally)."""
        videos = (videos - self.normalization_mean) / self.normalization_std
        batch_size, num_frames, _, H, W = videos.shape
        frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")

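        # Encode every frame independently with the ViT, then restore the (batch, time) layout.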
        visual_embs = self.visual_encoder(frame_batch)
        visual_embs = self.ln_vision(visual_embs)
        visual_embs = rearrange(
            visual_embs,
            "(b t) k d -> b t k d",
            b=batch_size,
            t=num_frames,
            k=visual_embs.size(1),
            d=visual_embs.size(2),
        )

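        # Add temporal position encodings across the frame axis.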
        visual_embs = self.temporal_encoding(visual_embs)

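        # Flatten all frame tokens into one sequence and let the learnable query tokens
        # cross-attend to it through the Q-Former.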
        encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
        encoder_attention_mask = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long).to(videos.device)
        query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
        visual_query_output = self.qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

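        # Mean-pool the query outputs and project into the shared embedding space (L2-normalized).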
        visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
        visual_proj = self.vision_proj(visual_cls_tokens)
        visual_proj = F.normalize(visual_proj, dim=-1)

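        # Separate the per-frame CLS tokens from the patch tokens, and fold the patch tokens
        # back onto their (h, w) spatial grid.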
        frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
        h = H // self.visual_encoder.patch_size
        w = W // self.visual_encoder.patch_size
        visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)

        return VideoEmbedderOutput(
            visual_proj=visual_proj,
            visual_embs=visual_embs,
            visual_query_output=visual_query_output,
            visual_cls_tokens=visual_cls_tokens,
            frame_cls_tokens=frame_cls_tokens,
        )

    def get_text_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextEmbedderOutput:
        """Compute projected text embeddings from tokenized text."""
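        # Text-only pass through the Q-Former's BERT encoder; the first ([CLS]) token is
        # projected into the shared embedding space and L2-normalized.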
        text_query_output = self.qformer.bert(
            input_ids=input_ids,
            attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
            return_dict=True,
        )
        text_proj = text_query_output.last_hidden_state[:, 0, :]
        text_proj = self.text_proj(text_proj)
        text_proj = F.normalize(text_proj, dim=-1)

        return TextEmbedderOutput(
            text_proj=text_proj,
            text_embs=text_query_output.last_hidden_state,
            text_query_output=text_query_output,
        )

    @classmethod
    @rank0_first
    def _init_qformer(
        cls: "CosmosEmbed1",
        num_query_tokens: int,
        encoder_width: int,
        vocab_size: int,
        hidden_size: int = 768,
    ) -> tuple[BertLMHeadModel, nn.Parameter]:
        """Convenience function for initializing QFormer module."""
        qformer = load_qformer(
            num_query_tokens=num_query_tokens,
            encoder_width=encoder_width,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
        )
        query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
        query_tokens.data.normal_(mean=0.0, std=0.02)
        return qformer, query_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load pretrained weights, with special handling when TransformerEngine is enabled."""
        config = kwargs.get("config", None)
        if config is None:
            config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)

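        # When TransformerEngine is requested, the checkpoint is first loaded into a plain
        # PyTorch build of the model, and the resulting weights are copied (non-strictly)
        # into a TransformerEngine-backed instance.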
        if config.transformer_engine:
            config_no_te = deepcopy(config)
            config_no_te.transformer_engine = False
            config_no_te.use_fp8 = False

            kwargs_no_te = deepcopy(kwargs)
            kwargs_no_te["config"] = config_no_te

            base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
            base_state_dict = base_model.state_dict()

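            # Fresh TE-backed instance; parameter names/layouts may not line up exactly with
            # the plain PyTorch model, hence the non-strict load below.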
            model_with_te = cls(config=config)

            missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)

            if missing:
                print(f"[TransformerEngine] Missing keys: {missing}")
            if unexpected:
                print(f"[TransformerEngine] Unexpected keys: {unexpected}")

            return model_with_te
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)


AutoModel.register(CosmosEmbed1Config, CosmosEmbed1)
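

# Minimal smoke-test sketch (not part of the published API). It assumes that
# `CosmosEmbed1Config` accepts the overrides below and provides usable defaults for its
# remaining fields, and it substitutes random tensors for real tokenized text and
# preprocessed video (float videos in [0, 1] with shape
# (batch, num_video_frames, 3, resolution, resolution)). For real use, load weights with
# `CosmosEmbed1.from_pretrained(...)` and tokenize text with the matching tokenizer.
if __name__ == "__main__":
    _config = CosmosEmbed1Config(transformer_engine=False, use_fp8=False)
    _model = CosmosEmbed1(_config).eval()

    _videos = torch.rand(2, _config.num_video_frames, 3, _config.resolution, _config.resolution)
    _input_ids = torch.randint(0, _config.vocab_size, (2, 16), dtype=torch.long)
    _attention_mask = torch.ones(2, 16, dtype=torch.long)

    with torch.no_grad():
        _out = _model(_videos, _input_ids, _attention_mask)

    # Both projections are L2-normalized, so a row-wise dot product is a cosine similarity.
    print((_out.visual_proj * _out.text_proj).sum(dim=-1))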