# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # // # // Licensed under the Apache License, Version 2.0 (the "License"); # // you may not use this file except in compliance with the License. # // You may obtain a copy of the License at # // # // http://www.apache.org/licenses/LICENSE-2.0 # // # // Unless required by applicable law or agreed to in writing, software # // distributed under the License is distributed on an "AS IS" BASIS, # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # // See the License for the specific language governing permissions and # // limitations under the License. from typing import List import torch import torch.distributed as dist import torch.nn.functional as F from torch import Tensor from common.distributed import get_device from common.distributed.advanced import ( get_next_sequence_parallel_rank, get_prev_sequence_parallel_rank, get_sequence_parallel_group, get_sequence_parallel_rank, get_sequence_parallel_world_size, ) from common.distributed.ops import Gather from common.logger import get_logger from models.video_vae_v3.modules.types import MemoryState logger = get_logger(__name__) def causal_conv_slice_inputs(x, split_size, memory_state): sp_size = get_sequence_parallel_world_size() sp_group = get_sequence_parallel_group() sp_rank = get_sequence_parallel_rank() if sp_group is None: return x assert memory_state != MemoryState.UNSET leave_out = 1 if memory_state != MemoryState.ACTIVE else 0 # Should have at least sp_size slices. num_slices = (x.size(2) - leave_out) // split_size assert num_slices >= sp_size, f"{num_slices} < {sp_size}" split_sizes = [split_size + leave_out] + [split_size] * (num_slices - 1) split_sizes += [x.size(2) - sum(split_sizes)] assert sum(split_sizes) == x.size(2) split_sizes = torch.tensor(split_sizes) slices_per_rank = len(split_sizes) // sp_size split_sizes = split_sizes.split( [slices_per_rank] * (sp_size - 1) + [len(split_sizes) - slices_per_rank * (sp_size - 1)] ) split_sizes = list(map(lambda s: s.sum().item(), split_sizes)) logger.debug(f"split_sizes: {split_sizes}") return x.split(split_sizes, dim=2)[sp_rank] def causal_conv_gather_outputs(x): sp_group = get_sequence_parallel_group() sp_size = get_sequence_parallel_world_size() if sp_group is None: return x # Communicate shapes. unpad_lens = torch.empty((sp_size,), device=get_device(), dtype=torch.long) local_unpad_len = torch.tensor([x.size(2)], device=get_device(), dtype=torch.long) torch.distributed.all_gather_into_tensor(unpad_lens, local_unpad_len, group=sp_group) # Padding to max_len for gather. max_len = unpad_lens.max() x_pad = F.pad(x, (0, 0, 0, 0, 0, max_len - x.size(2))).contiguous() # Gather outputs. x_pad = Gather.apply(sp_group, x_pad, 2, True) # Remove padding. x_pad_lists = list(x_pad.chunk(sp_size, dim=2)) for i, (x_pad, unpad_len) in enumerate(zip(x_pad_lists, unpad_lens)): x_pad_lists[i] = x_pad[:, :, :unpad_len] return torch.cat(x_pad_lists, dim=2) def get_output_len(conv_module, input_len, pad_len, dim=0): dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 return output_len def get_cache_size(conv_module, input_len, pad_len, dim=0): dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 remain_len = ( input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) ) overlap_len = dilated_kernerl_size - conv_module.stride[dim] cache_len = overlap_len + remain_len # >= 0 logger.debug( f"I:{input_len}, " f"P:{pad_len}, " f"K:{conv_module.kernel_size[dim]}, " f"S:{conv_module.stride[dim]}, " f"O:{output_len}, " f"Cache:{cache_len}" ) assert output_len > 0 return cache_len def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None): sp_group = get_sequence_parallel_group() sp_rank = get_sequence_parallel_rank() sp_size = get_sequence_parallel_world_size() send_dst = get_next_sequence_parallel_rank() recv_src = get_prev_sequence_parallel_rank() recv_buffer = None recv_req = None logger.debug( f"[sp{sp_rank}] cur_tensors:{[(t.size(), t.dtype) for t in tensor]}, times: {times}" ) if sp_rank == 0 or sp_group is None: if memory is not None: recv_buffer = memory.to(tensor[0]) elif times > 0: tile_repeat = [1] * tensor[0].ndim tile_repeat[2] = times recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) if cache_size != 0 and sp_group is not None: if sp_rank > 0: shape = list(tensor[0].size()) shape[2] = cache_size recv_buffer = torch.empty( *shape, device=tensor[0].device, dtype=tensor[0].dtype ).contiguous() recv_req = dist.irecv(recv_buffer, recv_src, group=sp_group) if sp_rank < sp_size - 1: if cache_size > tensor[-1].size(2) and len(tensor) == 1: logger.debug(f"[sp{sp_rank}] force concat before send {tensor[-1].size()}") if recv_req is not None: recv_req.wait() tensor[0] = torch.cat([recv_buffer, tensor[0]], dim=2) recv_buffer = None assert cache_size <= tensor[-1].size( 2 ), f"Not enough value to cache, got {tensor[-1].size()}, cache_size={cache_size}" dist.isend( tensor[-1][:, :, -cache_size:].detach().contiguous(), send_dst, group=sp_group ) if recv_req is not None: recv_req.wait() logger.debug( f"[sp{sp_rank}] recv_src:{recv_src}, " f"recv_buffer:{recv_buffer.size() if recv_buffer is not None else None}" ) return recv_buffer