import numpy as np import torch import torch.amp as amp from torch.backends.cuda import sdp_kernel from xfuser.core.distributed import get_sequence_parallel_rank from xfuser.core.distributed import get_sequence_parallel_world_size from xfuser.core.distributed import get_sp_group from xfuser.core.long_ctx_attention import xFuserLongContextAttention from ..modules.transformer import sinusoidal_embedding_1d def pad_freqs(original_tensor, target_len): seq_len, s1, s2 = original_tensor.shape pad_size = target_len - seq_len padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) return padded_tensor @amp.autocast("cuda", enabled=False) def rope_apply(x, grid_sizes, freqs): """ x: [B, L, N, C]. grid_sizes: [B, 3]. freqs: [M, C // 2]. """ s, n, c = x.size(1), x.size(2), x.size(3) // 2 # split freqs freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # loop over samples output = [] grid = [grid_sizes.tolist()] * x.size(0) for i, (f, h, w) in enumerate(grid): seq_len = f * h * w # precompute multipliers x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) freqs_i = torch.cat( [ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), ], dim=-1, ).reshape(seq_len, 1, -1) # apply rotary embedding sp_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() freqs_i = pad_freqs(freqs_i, s * sp_size) s_per_rank = s freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :] x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2) x_i = torch.cat([x_i, x[i, s:]]) # append to collection output.append(x_i) return torch.stack(output).float() def broadcast_should_calc(should_calc: bool) -> bool: import torch.distributed as dist device = torch.cuda.current_device() int_should_calc = 1 if should_calc else 0 tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8) dist.broadcast(tensor, src=0) should_calc = tensor.item() == 1 return should_calc def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None): """ x: A list of videos each with shape [C, T, H, W]. t: [B]. context: A list of text embeddings each with shape [L, C]. """ if self.model_type == "i2v": assert clip_fea is not None and y is not None # params device = self.patch_embedding.weight.device if self.freqs.device != device: self.freqs = self.freqs.to(device) if y is not None: x = torch.cat([x, y], dim=1) # embeddings x = self.patch_embedding(x) grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long) x = x.flatten(2).transpose(1, 2) if self.flag_causal_attention: frame_num = grid_sizes[0] height = grid_sizes[1] width = grid_sizes[2] block_num = frame_num // self.num_frame_per_block range_tensor = torch.arange(block_num).view(-1, 1) range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) casual_mask = casual_mask.repeat(1, height, width, 1, height, width) casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0) # time embeddings with amp.autocast("cuda", dtype=torch.float32): if t.dim() == 2: b, f = t.shape _flag_df = True else: _flag_df = False e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim if self.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=device) fps_emb = self.fps_embedding(fps).float() if _flag_df: e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) else: e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) if _flag_df: e = e.view(b, f, 1, 1, self.dim) e0 = e0.view(b, f, 1, 1, 6, self.dim) e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) e0 = e0.transpose(1, 2).contiguous() assert e.dtype == torch.float32 and e0.dtype == torch.float32 # context context = self.text_embedding(context) if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) # arguments if e0.ndim == 4: e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()] kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask) if self.enable_teacache: modulated_inp = e0 if self.use_ref_steps else e # teacache if self.cnt % 2 == 0: # even -> conditon self.is_even = True if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: should_calc_even = True self.accumulated_rel_l1_distance_even = 0 else: rescale_func = np.poly1d(self.coefficients) self.accumulated_rel_l1_distance_even += rescale_func( ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()) .cpu() .item() ) if self.accumulated_rel_l1_distance_even < self.teacache_thresh: should_calc_even = False else: should_calc_even = True self.accumulated_rel_l1_distance_even = 0 self.previous_e0_even = modulated_inp.clone() else: # odd -> unconditon self.is_even = False if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: should_calc_odd = True self.accumulated_rel_l1_distance_odd = 0 else: rescale_func = np.poly1d(self.coefficients) self.accumulated_rel_l1_distance_odd += rescale_func( ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()) .cpu() .item() ) if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: should_calc_odd = False else: should_calc_odd = True self.accumulated_rel_l1_distance_odd = 0 self.previous_e0_odd = modulated_inp.clone() x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] if self.enable_teacache: if self.is_even: should_calc_even = broadcast_should_calc(should_calc_even) if not should_calc_even: x += self.previous_residual_even else: ori_x = x.clone() for block in self.blocks: x = block(x, **kwargs) ori_x.mul_(-1) ori_x.add_(x) self.previous_residual_even = ori_x else: should_calc_odd = broadcast_should_calc(should_calc_odd) if not should_calc_odd: x += self.previous_residual_odd else: ori_x = x.clone() for block in self.blocks: x = block(x, **kwargs) ori_x.mul_(-1) ori_x.add_(x) self.previous_residual_odd = ori_x self.cnt += 1 if self.cnt >= self.num_steps: self.cnt = 0 else: # Context Parallel for block in self.blocks: x = block(x, **kwargs) # head if e.ndim == 3: e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] x = self.head(x, e) # Context Parallel x = get_sp_group().all_gather(x, dim=1) # unpatchify x = self.unpatchify(x, grid_sizes) return x.float() def usp_attn_forward(self, x, grid_sizes, freqs, block_mask): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] seq_lens(Tensor): Shape [B] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim half_dtypes = (torch.float16, torch.bfloat16) def half(x): return x if x.dtype in half_dtypes else x.to(torch.bfloat16) # query, key, value function def qkv_fn(x): q = self.norm_q(self.q(x)).view(b, s, n, d) k = self.norm_k(self.k(x)).view(b, s, n, d) v = self.v(x).view(b, s, n, d) return q, k, v x = x.to(self.q.weight.dtype) q, k, v = qkv_fn(x) if not self._flag_ar_attention: q = rope_apply(q, grid_sizes, freqs) k = rope_apply(k, grid_sizes, freqs) else: q = rope_apply(q, grid_sizes, freqs) k = rope_apply(k, grid_sizes, freqs) q = q.to(torch.bfloat16) k = k.to(torch.bfloat16) v = v.to(torch.bfloat16) # x = torch.nn.functional.scaled_dot_product_attention( # q.transpose(1, 2), # k.transpose(1, 2), # v.transpose(1, 2), # ).transpose(1, 2).contiguous() with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): x = ( torch.nn.functional.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask ) .transpose(1, 2) .contiguous() ) x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size) # output x = x.flatten(2) x = self.o(x) return x