# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple

import numpy as np
import torch
from einops import rearrange, repeat
from megatron.core import parallel_state


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
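
# Example (illustrative): embedding 4 positions into 8 dimensions yields a (4, 8)
# array whose first 4 columns are sine terms and last 4 columns are cosine terms.
#   >>> emb = get_1d_sincos_pos_embed_from_grid(8, np.arange(4))
#   >>> emb.shape
#   (4, 8)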


def _rotate_half_te(x: torch.Tensor) -> torch.Tensor:
    """
    Change sign so the last dimension becomes [-odd, +even].
    Adapted from TransformerEngine.
    Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
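
# Example (illustrative): the two halves of the last dimension are swapped and the
# second half is negated, i.e. [x1, x2] -> [-x2, x1].
#   >>> _rotate_half_te(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
#   tensor([[-3., -4.,  1.,  2.]])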


def _apply_rotary_pos_emb_te(
    t: torch.Tensor,
    cos_freqs: torch.Tensor,
    sin_freqs: torch.Tensor,
) -> torch.Tensor:
    """
    Apply the rotary positional embedding to the input tensor.
    Adapted from TransformerEngine.
    Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[b, s, h, d]`, on which
        rotary positional embedding will be applied.
    cos_freqs: torch.Tensor
        Cosine component of the rotary positional embedding, of dtype float and a
        shape broadcastable against `t`, e.g. `[1, s, 1, d]` for `[b, s, h, d]` inputs.
    sin_freqs: torch.Tensor
        Sine component of the rotary positional embedding, of dtype float and the
        same shape as `cos_freqs`.
    """
    rot_dim = cos_freqs.shape[-1]
    # Ideally t_pass is empty, so the rotary embedding is applied to all of t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # The first term is the cosine component; the second is the sine component,
    # whose signs are adjusted by the _rotate_half_te method.
    t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs)
    output = torch.cat((t, t_pass), dim=-1)
    return output
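
# Example (illustrative shapes only): with q/k-style input [b, s, h, d] and
# broadcastable cos/sin of shape [1, s, 1, d], the output keeps the input shape.
#   >>> t = torch.randn(2, 8, 4, 16)
#   >>> freqs = torch.randn(1, 8, 1, 16)
#   >>> _apply_rotary_pos_emb_te(t, freqs.cos(), freqs.sin()).shape
#   torch.Size([2, 8, 4, 16])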


def get_pos_emb_on_this_cp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor:
    """
    Get the position embedding for the current context parallel rank.

    Args:
        pos_emb (torch.Tensor): The position embedding tensor.
        seq_dim (int): The sequence dimension to slice.

    Returns:
        torch.Tensor: The position embedding tensor for the current rank.
    """
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = parallel_state.get_context_parallel_rank()
    cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
        non_blocking=True
    )
    pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :])
    pos_emb = pos_emb.index_select(seq_dim, cp_idx)
    pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
    return pos_emb
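
# Load-balancing note (illustrative): the sequence is split into 2 * cp_size chunks
# and each rank takes one chunk from each end, e.g. with cp_size=2:
#   rank 0 -> chunks [0, 3], rank 1 -> chunks [1, 2]
# matching the "zigzag" sharding used for causal attention under context parallelism.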


def get_pos_emb_on_this_sptp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor:
    """
    Get the position embedding for the current tensor parallel rank
    (only used when sequence parallelism is turned on).

    Args:
        pos_emb (torch.Tensor): The position embedding tensor.
        seq_dim (int): The sequence dimension to slice.

    Returns:
        torch.Tensor: The position embedding tensor for the current rank.
    """
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    pos_emb_chunks = torch.chunk(pos_emb, tp_size, dim=seq_dim)
    pos_emb = pos_emb_chunks[tp_rank]
    return pos_emb
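
# Example (illustrative): with tp_size=4 and a 1024-token position embedding, each
# tensor-parallel rank receives a contiguous 256-token slice along seq_dim.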


class RotaryPositionEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding module as described in the paper:
    https://arxiv.org/abs/2104.09864

    This module implements rotary positional embeddings, which are used to
    enhance the performance of transformer models.

    Args:
        dim (int): Dimensionality of the input tensor.
        max_position_embeddings (Optional[int]): Maximum number of position embeddings.
        original_max_position_embeddings (Optional[int]): Original maximum number of position embeddings
            (before any context extension).
        rope_theta (Optional[float]): Base for the frequency calculation.
        apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another RoPE extensioN).
        scale (Optional[int]): Scaling factor for the frequency calculation.
        extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension.
        attn_factor (Optional[int]): Attention factor for the frequency calculation.
        beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation.
        beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation.
        rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
        latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
        original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs.
        pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: Optional[int] = None,
        original_max_position_embeddings: Optional[int] = None,
        rope_theta: Optional[float] = 10000.0,
        apply_yarn: Optional[bool] = False,
        scale: Optional[int] = None,
        extrapolation_factor: Optional[int] = 1,
        attn_factor: Optional[int] = 1,
        beta_fast: Optional[int] = 32,
        beta_slow: Optional[int] = 1,
        rope_dim: Optional[str] = "1D",
        latent_shape: Optional[List[int]] = None,
        original_latent_shape: Optional[List[int]] = None,
        pad_to_multiple_of: Optional[int] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.rope_theta = rope_theta
        self.apply_yarn = apply_yarn
        self.scale = scale
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = 1.0
        self.rope_dim = rope_dim
        self.latent_shape = latent_shape
        self.original_latent_shape = original_latent_shape
        self.pad_to_multiple_of = pad_to_multiple_of
        self.get_inv_freq(torch.cuda.current_device())

    def get_mscale(self, scale: float = 1.0) -> float:
        """Get the magnitude scaling factor for YaRN."""
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0
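
    # Example (illustrative): a 4x context extension (scale=4) gives
    # 0.1 * math.log(4) + 1.0 ~= 1.139, while scale <= 1 leaves magnitudes unchanged.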

    def forward(self, seq_len: Optional[int] = None) -> torch.Tensor:
        """
        Forward pass for the rotary position embedding.

        Args:
            seq_len (Optional[int]): Length of the sequence.

        Returns:
            torch.Tensor: The computed frequencies for the positional embedding.
        """
        if self.apply_yarn and seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self.freqs = self.compute_freqs()
        return self.freqs

    def compute_freqs(self) -> torch.Tensor:
        """Compute the frequency embedding for the configured RoPE dimensionality."""
        self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda()
        if self.rope_dim == "1D":
            emb = torch.einsum("i,j->ij", self.seq, self.inv_freq)
        elif self.rope_dim == "2D":
            H, W = self.latent_shape
            half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
            half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
            emb = torch.cat(
                [
                    repeat(half_emb_h, "h d -> h w d", w=W),
                    repeat(half_emb_w, "w d -> h w d", h=H),
                ]
                * 2,
                dim=-1,
            )
            emb = rearrange(emb, "h w d -> (h w) 1 1 d").float()
        elif self.rope_dim == "3D":
            T, H, W = self.latent_shape
            half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq)
            half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
            half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
            emb = torch.cat(
                [
                    repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                    repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                    repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
                ]
                * 2,
                dim=-1,
            )
            emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
        return emb

    def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
        """Get the scale factors for YaRN."""
        # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called
        # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code.
        high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
        low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
        # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear
        # interpolation in between.
        smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
        # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency.
        scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
        return scale_factors
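
    # Illustrative behavior: inv_freq entries at or below low_freq_cutoff are divided
    # by self.scale (pure interpolation), entries at or above high_freq_cutoff are
    # kept as-is (pure extrapolation), and entries in between blend linearly.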

    def get_inv_freq(self, device: torch.device) -> None:
        """Get the inverse frequency."""
        if self.rope_dim == "1D":
            assert self.max_position_embeddings is not None, "Max position embeddings required."
            inv_freq = 1.0 / (
                self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
            )
            if self.apply_yarn:
                assert self.original_max_position_embeddings is not None, "Original max position embeddings required."
                assert self.beta_slow is not None, "Beta slow value required."
                assert self.beta_fast is not None, "Beta fast value required."
                scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings)
                # Apply the scaling factors to inv_freq.
                inv_freq = inv_freq * scale_factors
                # Set the magnitude scaling factor.
                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
            self.max_seq_len_cached = self.max_position_embeddings
            self.inv_freq = inv_freq
        elif self.rope_dim == "2D":
            assert self.latent_shape is not None, "Latent shape required."
            dim_h = self.dim // 2
            # Note: the exponent is (arange / dim_h); the parentheses matter, since
            # `theta ** x / y` would compute (theta ** x) / y instead.
            spatial_inv_freq = 1.0 / (
                self.rope_theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h)
            )
            if self.apply_yarn:
                assert self.original_latent_shape is not None, "Original latent shape required."
                assert self.beta_slow is not None, "Beta slow value required."
                assert self.beta_fast is not None, "Beta fast value required."
                scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0])
                spatial_inv_freq = spatial_inv_freq * scale_factors
                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
            self.spatial_inv_freq = spatial_inv_freq
            self.max_seq_len_cached = max(self.latent_shape)
        elif self.rope_dim == "3D":
            assert self.latent_shape is not None, "Latent shape required."
            dim_h = self.dim // 6 * 2
            dim_t = self.dim - 2 * dim_h
            self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h
            spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
            self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t
            temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
            if self.apply_yarn:
                assert self.original_latent_shape is not None, "Original latent shape required."
                assert self.beta_slow is not None, "Beta slow value required."
                assert self.beta_fast is not None, "Beta fast value required."
                scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
                spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
                scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
                temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
            self.spatial_inv_freq = spatial_inv_freq
            self.temporal_inv_freq = temporal_inv_freq
            self.max_seq_len_cached = max(self.latent_shape)
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
        self.freqs = self.compute_freqs()
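
# Example usage (illustrative; assumes a CUDA device, since __init__ builds the
# frequency table on torch.cuda.current_device()):
#   >>> rope = RotaryPositionEmbedding(dim=128, max_position_embeddings=4096)
#   >>> rope(seq_len=2048).shape   # 1D freqs cover all cached positions
#   torch.Size([4096, 64])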


class RotaryPositionEmbeddingTE(RotaryPositionEmbedding):
    """
    Rotary Position Embedding with context parallelism support.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )

    def forward(self, seq_len: int, training_type: Optional[str] = None) -> torch.Tensor:
        """
        Create rotary position embedding frequencies.

        Args:
            seq_len (int): Sequence length of a sample.
            training_type (Optional[str]): Training type; "text_to_video" prepends a
                <bov> position before padding.

        Returns:
            torch.Tensor: The computed positional embeddings.
        """
        if self.rope_dim == "1D":
            freqs = super().forward(seq_len=seq_len)
            emb = torch.cat((freqs, freqs), dim=-1)
            emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
        elif self.rope_dim in ["2D", "3D"]:
            emb = super().forward(seq_len=seq_len)
            if training_type == "text_to_video":
                # Since a <bov> token is prepended to the video for text2video, the
                # position embedding is also extended by one token at the beginning.
                bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device)
                emb = torch.cat((bov_pe, emb), dim=0)
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
        if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
            # Round up to the nearest multiple of pad_to_multiple_of.
            pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
            emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0)
        return emb
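
# Example usage (illustrative; assumes CUDA and a 3D latent configuration):
#   >>> rope = RotaryPositionEmbeddingTE(dim=128, rope_dim="3D", latent_shape=[4, 8, 8])
#   >>> rope(seq_len=4 * 8 * 8).shape   # [t*h*w, 1, 1, dim], TE-style layout
#   torch.Size([256, 1, 1, 128])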


class RotaryPositionEmbeddingPytorch(RotaryPositionEmbedding):
    """
    Rotary Position Embedding with PyTorch-specific adjustments.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        if self.rope_dim == "1D":
            emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1)
        elif self.rope_dim in ["2D", "3D"]:
            emb = rearrange(self.freqs, "s 1 1 d -> s d").float()
        self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dimensions of the input tensor."""
        x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
        x1 = x_reshaped[..., 0]
        x2 = x_reshaped[..., 1]
        output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape)
        return output

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the rotary position embedding.

        Args:
            q (torch.Tensor): Query tensor.
            k (torch.Tensor): Key tensor.
            input_pos (Optional[torch.Tensor]): Starting position for the sequence.
            seq_len (Optional[int]): Length of the sequence.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
        """
        if self.apply_yarn and seq_len > self.max_seq_len_cached:
            freqs = super().forward(seq_len)
            if self.rope_dim == "1D":
                emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1)
            elif self.rope_dim in ["2D", "3D"]:
                emb = rearrange(freqs, "s 1 1 d -> s d").float()
            else:
                raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
            self.register_buffer(
                "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )
            self.register_buffer(
                "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )
        if input_pos is not None:
            cos_cached = self.cos_cached[:, input_pos]
            sin_cached = self.sin_cached[:, input_pos]
        else:
            assert (
                self.cos_cached.shape[1] >= seq_len
            ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}."
            cos_cached = self.cos_cached[:, :seq_len, ...]
            sin_cached = self.sin_cached[:, :seq_len, ...]
        xq = q * cos_cached + self.rotate_half(q) * sin_cached
        xk = k * cos_cached + self.rotate_half(k) * sin_cached
        return xq.type_as(q), xk.type_as(k)
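
# Example usage (illustrative; assumes CUDA):
#   >>> rope = RotaryPositionEmbeddingPytorch(dim=128, max_position_embeddings=4096)
#   >>> q = torch.randn(2, 1024, 8, 128, device="cuda")
#   >>> k = torch.randn(2, 1024, 8, 128, device="cuda")
#   >>> q, k = rope(q, k, seq_len=1024)   # shapes unchanged: [b, s, h, d]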


class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding):
    """
    Rotary Position Embedding that works in the same way as the TransformerEngine RoPE
    (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)
    """

    def __init__(
        self,
        seq_len: int,
        training_type: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type)
        emb = emb.transpose(0, 1).contiguous()  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
        assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}"
        # Take cos/sin first, then convert dtype, for better precision.
        self.register_buffer("cos_cached", torch.cos(emb), persistent=False)
        self.register_buffer("sin_cached", torch.sin(emb), persistent=False)

    def create_rope_freqs(self, seq_len: int, training_type: Optional[str] = None) -> torch.Tensor:
        """
        Create rotary position embedding frequencies.

        Args:
            seq_len (int): Sequence length of a sample.
            training_type (Optional[str]): Training type; "text_to_video" prepends a
                <bov> position before padding.

        Returns:
            torch.Tensor: The computed positional embeddings.
        """
        if self.rope_dim == "1D":
            freqs = super().forward(seq_len=seq_len)
            emb = torch.cat((freqs, freqs), dim=-1)
            emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
        elif self.rope_dim in ["2D", "3D"]:
            emb = super().forward(seq_len=seq_len)
            if training_type == "text_to_video":
                # Since a <bov> token is prepended to the video for text2world, the
                # position embedding is also extended by one token at the beginning.
                bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device)
                emb = torch.cat((bov_pe, emb), dim=0)
        else:
            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
        if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
            # Round up to the nearest multiple of pad_to_multiple_of.
            pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
            emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0)
        return emb

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the cached rotary embedding to query and key tensors of shape [b, s, h, d]."""
        if q.dtype != self.cos_cached.dtype:
            self.cos_cached = self.cos_cached.to(q.dtype)
            self.sin_cached = self.sin_cached.to(q.dtype)

        cos_emb = self.cos_cached
        sin_emb = self.sin_cached
        if input_pos is not None:
            cos_emb = cos_emb[:, input_pos, :, :]
            sin_emb = sin_emb[:, input_pos, :, :]
        elif seq_len is not None:
            cos_emb = cos_emb[:, :seq_len, :, :]
            sin_emb = sin_emb[:, :seq_len, :, :]
        q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb)
        k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb)
        return q, k
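
# Example usage (illustrative; assumes CUDA):
#   >>> rope = RotaryPositionEmbeddingPytorchV2(
#   ...     seq_len=256, dim=128, rope_dim="3D", latent_shape=[4, 8, 8]
#   ... )
#   >>> q = torch.randn(2, 256, 8, 128, device="cuda")
#   >>> k = torch.randn(2, 256, 8, 128, device="cuda")
#   >>> q, k = rope(q, k)   # shapes unchanged: [b, s, h, d]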


class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding):
    """
    Rotary Position Embedding that works in the same way as
    mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py)
    or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py)
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        if self.rope_dim == "1D":
            emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1)
        elif self.rope_dim in ["2D", "3D"]:
            emb = rearrange(self.freqs, "s 1 1 d -> s d").float()
        self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dimensions of the input tensor."""
        x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
        x1 = x_reshaped[..., 0]
        x2 = x_reshaped[..., 1]
        output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape)
        return output

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the rotary position embedding.

        Args:
            q (torch.Tensor): Query tensor.
            k (torch.Tensor): Key tensor.
            input_pos (Optional[torch.Tensor]): Starting position for the sequence.
            seq_len (Optional[int]): Length of the sequence.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
        """
        if self.apply_yarn and seq_len > self.max_seq_len_cached:
            freqs = super().forward(seq_len)
            if self.rope_dim == "1D":
                emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1)
            elif self.rope_dim in ["2D", "3D"]:
                emb = rearrange(freqs, "s 1 1 d -> s d").float()
            else:
                raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
            self.register_buffer(
                "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )
            self.register_buffer(
                "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
            )
        if input_pos is not None:
            cos_cached = self.cos_cached[:, input_pos]
            sin_cached = self.sin_cached[:, input_pos]
        else:
            assert (
                self.cos_cached.shape[1] >= seq_len
            ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}."
            cos_cached = self.cos_cached[:, :seq_len, ...]
            sin_cached = self.sin_cached[:, :seq_len, ...]
        xq = q * cos_cached + self.rotate_half(q) * sin_cached
        xk = k * cos_cached + self.rotate_half(k) * sin_cached
        return xq.type_as(q), xk.type_as(k)


class SinCosPosEmbAxisTE(torch.nn.Module):
    """Axis-factorized sine-cosine position embedding over the (T, H, W) latent axes."""

    def __init__(
        self,
        dim: int,
        latent_shape: Optional[List[int]] = None,
        pad_to_multiple_of: Optional[int] = None,
        dtype: torch.dtype = torch.bfloat16,
        device="cuda",
        **kwargs,
    ):
        """
        Args:
            dim (int): Dimensionality of the input tensor.
            latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
            pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
            dtype (torch.dtype): Data type of the position embedding tensor.
        """
        super().__init__()
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.latent_shape = latent_shape
        T, H, W = latent_shape
        emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H))
        emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W))
        emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T))

        self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device=device), persistent=False)
        self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device=device), persistent=False)
        self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device=device), persistent=False)
        self.pad_to_multiple_of = pad_to_multiple_of

    def forward(
        self,
        training_type: Optional[str] = None,
    ) -> torch.Tensor:
        T, H, W = self.latent_shape
        emb = torch.cat(
            [
                repeat(self.pos_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(self.pos_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(self.pos_emb_w, "w d -> t h w d", t=T, h=H),
            ],
            dim=-1,
        )
        # Flatten the T, H, W dimensions.
        emb = rearrange(emb, "t h w d -> (t h w) d")

        if training_type == "text_to_video":
            bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)
            emb = torch.cat((bov_pe, emb), dim=0)
        if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
            pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
            emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0)
        seq_len, dim = emb.shape
        emb = emb.reshape(1, seq_len, dim)
        return emb
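

# Example usage (illustrative; assumes CUDA for the default device):
#   >>> pos_emb = SinCosPosEmbAxisTE(dim=128, latent_shape=[4, 8, 8])
#   >>> pos_emb().shape   # (1, t*h*w, dim)
#   torch.Size([1, 256, 128])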