# hymm_sp/helpers.py: RoPE frequency helpers for Hunyuan-Avatar.
import collections.abc
from itertools import repeat
from typing import List, Union

import torch

from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd

def _ntuple(n):
    """Return a parser that normalizes a scalar or an iterable into an n-tuple."""
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                # Broadcast a single-element iterable to length n.
                x = tuple(repeat(x[0], n))
            return x
        # Scalars are repeated n times.
        return tuple(repeat(x, n))
    return parse

to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
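
# Usage sketch for the tuple helpers (values are illustrative):
#   to_2tuple(3)      -> (3, 3)
#   to_2tuple((4, 5)) -> (4, 5)
#   to_2tuple([7])    -> (7, 7)   # single-element iterables are broadcast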

def get_rope_freq_from_size(latents_size, ndim, target_ndim, args,
                            rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
                            rope_interpolation_factor: Union[float, List[float]] = 1.0,
                            concat_dict=None):
    """Compute real-valued RoPE cos/sin tables for a latent volume, given model args."""
    if isinstance(args.patch_size, int):
        assert all(s % args.patch_size == 0 for s in latents_size), \
            f"Latent size (last {ndim} dimensions) should be divisible by patch size ({args.patch_size}), " \
            f"but got {latents_size}."
        rope_sizes = [s // args.patch_size for s in latents_size]
    elif isinstance(args.patch_size, list):
        assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
            f"Latent size (last {ndim} dimensions) should be divisible by patch size ({args.patch_size}), " \
            f"but got {latents_size}."
        rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
    else:
        raise TypeError(f"args.patch_size must be an int or a list, got {type(args.patch_size)}.")
    if len(rope_sizes) != target_ndim:
        # Pad missing leading axes (e.g. the time axis) with size 1.
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes
    head_dim = args.hidden_size // args.num_heads
    rope_dim_list = args.rope_dim_list
    if rope_dim_list is None:
        # Split the head dimension evenly across the temporal/spatial axes.
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
rope_sizes,
theta=args.rope_theta,
use_real=True,
theta_rescale_factor=rope_theta_rescale_factor,
interpolation_factor=rope_interpolation_factor,
concat_dict=concat_dict)
return freqs_cos, freqs_sin
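
# Shape of concat_dict, as consumed below. The 'mode' and 'bias' keys are taken
# from the code; the numeric values here are illustrative, not canonical:
#   {} or None                          -> no extra positions
#   {'mode': 'timecat',   'bias': -1}   -> prepend one slice of positions whose
#                                          first-axis coordinate is set to bias
#   {'mode': 'timecat-w', 'bias': -1}   -> same, but the prepended slice's third
#                                          coordinate is also offset by the last
#                                          grid size (see the OminiControl link)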
def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000.0, use_real=False,
                                theta_rescale_factor: Union[float, List[float]] = 1.0,
                                interpolation_factor: Union[float, List[float]] = 1.0,
                                concat_dict=None):
    """Build an n-dimensional rotary position embedding over a meshgrid.

    Axis i of the grid is encoded with rope_dim_list[i] channels; concat_dict can
    prepend extra positions (e.g. for a reference frame) along the first grid axis.
    """
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]
    if concat_dict:  # handles both None and {}
        if concat_dict['mode'] == 'timecat':
            # Prepend one slice of positions along the first grid axis, with its
            # first coordinate overridden by concat_dict['bias'].
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            grid = torch.cat([bias, grid], dim=1)
        elif concat_dict['mode'] == 'timecat-w':
            # Same as 'timecat', but also offset the prepended slice's third
            # coordinate by the last grid size.
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            bias[2] += start[-1]  # ref: https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
            grid = torch.cat([bias, grid], dim=1)
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), \
        "len(theta_rescale_factor) must equal len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), \
        "len(interpolation_factor) must equal len(rope_dim_list)"
    # Use 1/ndim of the head channels to encode each grid axis.
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
                                      theta_rescale_factor=theta_rescale_factor[i],
                                      interpolation_factor=interpolation_factor[i])  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
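

if __name__ == "__main__":
    # Smoke test: a minimal sketch assuming the repo's posemb_layers module is
    # importable. The config below is a toy stand-in for the pipeline's parsed
    # arguments, not a production configuration.
    from types import SimpleNamespace

    args = SimpleNamespace(
        patch_size=2,
        hidden_size=1024,
        num_heads=16,
        rope_theta=10000,
        rope_dim_list=[16, 24, 24],  # sums to head_dim = 1024 // 16 = 64
    )
    # Latent volume of (T, H, W) = (8, 32, 32); patchified to (4, 16, 16).
    cos, sin = get_rope_freq_from_size([8, 32, 32], ndim=3, target_ndim=3, args=args)
    print("freqs_cos:", tuple(cos.shape), "freqs_sin:", tuple(sin.shape))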