|
import torch |
|
from typing import Union, Tuple |
|
|
|
|
|
def _to_tuple(x, dim=2): |
|
if isinstance(x, int): |
|
return (x,) * dim |
|
elif len(x) == dim: |
|
return x |
|
else: |
|
raise ValueError(f"Expected length {dim} or int, but got {x}") |
|
|
|
|
|
def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...] stacked meshgrid, float32.
    """
    if len(args) == 0:
        # Only a count given: each axis runs over [0, num) with unit step.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start/stop given, unit step -> num is the interval length per axis.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start/stop/num all given explicitly.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # linspace over n + 1 points with the endpoint dropped yields n evenly
    # spaced samples in the half-open interval [a, b).
    axes = [
        torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        for a, b, n in zip(start, stop, num)
    ]
    mesh = torch.meshgrid(*axes, indexing="ij")
    return torch.stack(mesh, dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_nd_rotary_pos_embed(
    rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
        freq_scaling (float, optional): Frequency rescale factor, as proposed in mmaudio. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2] complex tensor, or (cos, sin) pair of
            [HW, D] real tensors when ``use_real`` is True.
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)

    # One 1-D rotary embedding per spatial axis, each computed over the
    # flattened grid positions of that axis.
    embs = [
        get_1d_rotary_pos_embed(
            head_dim,
            grid[axis].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor,
            freq_scaling=freq_scaling,
        )
        for axis, head_dim in enumerate(rope_dim_list)
    ]

    if not use_real:
        # Complex path: each emb is [HW, d_i/2]; concatenate along features.
        return torch.cat(embs, dim=1)

    # Real path: each emb is a (cos, sin) pair of [HW, d_i] tensors.
    cos = torch.cat([pair[0] for pair in embs], dim=1)
    sin = torch.cat([pair[1] for pair in embs], dim=1)
    return cos, sin
|
|
|
|
|
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    freq_scaling: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta (NTK-aware scaling). Defaults to 1.0.
        freq_scaling (float, optional): Frequency rescale factor, as proposed in mmaudio. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        # A scalar means "positions 0..pos-1".
        pos = torch.arange(pos).float()

    # NTK-aware scaling (bloc97): rescale the rotary base so longer contexts
    # interpolate without fine-tuning. The canonical exponent is
    # dim / (dim - 2) — frequencies step by 2 up to dim — not dim / (dim - 1).
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    # Per-pair inverse frequencies: theta^(-2k/dim) for k = 0 .. dim//2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    freqs *= freq_scaling
    freqs = torch.outer(pos, freqs)  # [S, D/2]
    if use_real:
        # Duplicate each frequency so cos/sin line up with interleaved
        # (even, odd) feature pairs: [S, D].
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
        return freqs_cos, freqs_sin
    else:
        # complex64 unit phasors: cos(f) + i*sin(f), shape [S, D/2].
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
|
|