Spaces:
Running
on
Zero
Running
on
Zero
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
# // | |
# // Licensed under the Apache License, Version 2.0 (the "License"); | |
# // you may not use this file except in compliance with the License. | |
# // You may obtain a copy of the License at | |
# // | |
# // http://www.apache.org/licenses/LICENSE-2.0 | |
# // | |
# // Unless required by applicable law or agreed to in writing, software | |
# // distributed under the License is distributed on an "AS IS" BASIS, | |
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# // See the License for the specific language governing permissions and | |
# // limitations under the License. | |
from typing import List | |
import torch | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
from torch import Tensor | |
from common.distributed import get_device | |
from common.distributed.advanced import ( | |
get_next_sequence_parallel_rank, | |
get_prev_sequence_parallel_rank, | |
get_sequence_parallel_group, | |
get_sequence_parallel_rank, | |
get_sequence_parallel_world_size, | |
) | |
from common.distributed.ops import Gather | |
from common.logger import get_logger | |
from models.video_vae_v3.modules.types import MemoryState | |
# Module-level logger; the helpers below report split sizes, cache sizes and
# send/recv buffer shapes through it at debug level.
logger = get_logger(__name__)
def causal_conv_slice_inputs(x, split_size, memory_state):
    """Return the current sequence-parallel rank's share of ``x`` along dim 2.

    Dim 2 is cut into chunks of ``split_size`` frames. When ``memory_state`` is
    not ``MemoryState.ACTIVE`` the first chunk carries one extra frame, and any
    frames left over form a trailing (possibly empty) chunk. Chunks are then
    handed out to ranks in contiguous runs of ``len(chunks) // sp_size``; the
    last rank additionally takes whatever chunks remain.

    Args:
        x: tensor to slice along dim 2.
        split_size: number of frames per chunk.
        memory_state: must not be ``MemoryState.UNSET``.

    Returns:
        ``x`` unchanged when there is no sequence-parallel group, otherwise the
        slice of ``x`` owned by the current rank.
    """
    sp_size = get_sequence_parallel_world_size()
    sp_group = get_sequence_parallel_group()
    sp_rank = get_sequence_parallel_rank()
    if sp_group is None:
        return x

    assert memory_state != MemoryState.UNSET
    total_len = x.size(2)
    extra = 0 if memory_state == MemoryState.ACTIVE else 1

    # Every rank must receive at least one full chunk.
    num_slices = (total_len - extra) // split_size
    assert num_slices >= sp_size, f"{num_slices} < {sp_size}"

    chunk_lens = [split_size + extra]
    chunk_lens.extend([split_size] * (num_slices - 1))
    chunk_lens.append(total_len - sum(chunk_lens))  # leftover frames (may be 0)
    assert sum(chunk_lens) == total_len

    # Sum consecutive runs of chunk lengths into one length per rank.
    per_rank = len(chunk_lens) // sp_size
    rank_lens = [
        sum(chunk_lens[r * per_rank : (r + 1) * per_rank]) for r in range(sp_size - 1)
    ]
    rank_lens.append(sum(chunk_lens[(sp_size - 1) * per_rank :]))
    logger.debug(f"split_sizes: {rank_lens}")
    return x.split(rank_lens, dim=2)[sp_rank]
def causal_conv_gather_outputs(x):
    """Gather every rank's output along dim 2 and concatenate them in rank order.

    Ranks may hold different lengths along dim 2, so the lengths are exchanged
    first. Each local tensor is right-padded on dim 2 to the longest length,
    the padded tensors are gathered with ``Gather``, and the padding is then
    stripped from each rank's chunk before concatenation.

    Args:
        x: local output tensor; dim 2 is the (possibly uneven) split dimension.

    Returns:
        ``x`` unchanged when there is no sequence-parallel group, otherwise the
        concatenation of all ranks' unpadded outputs along dim 2.
    """
    sp_group = get_sequence_parallel_group()
    sp_size = get_sequence_parallel_world_size()
    if sp_group is None:
        return x

    # Communicate shapes.
    unpad_lens = torch.empty((sp_size,), device=get_device(), dtype=torch.long)
    local_unpad_len = torch.tensor([x.size(2)], device=get_device(), dtype=torch.long)
    torch.distributed.all_gather_into_tensor(unpad_lens, local_unpad_len, group=sp_group)
    # Bring the lengths to the host once. Using individual tensor elements as
    # Python ints (in F.pad and in slicing) would force a device sync per use.
    unpad_lens = unpad_lens.tolist()

    # Pad dim 2 up to max_len. F.pad takes (left, right) pairs starting from the
    # last dim, so emit a zero pair for every dim after dim 2.
    max_len = max(unpad_lens)
    pad = (0, 0) * (x.ndim - 3) + (0, max_len - x.size(2))
    x_pad = F.pad(x, pad).contiguous()

    # Gather outputs.
    gathered = Gather.apply(sp_group, x_pad, 2, True)

    # Remove padding: each rank's chunk is max_len long; keep only its real frames.
    chunks = gathered.chunk(sp_size, dim=2)
    return torch.cat(
        [chunk[:, :, :unpad_len] for chunk, unpad_len in zip(chunks, unpad_lens)],
        dim=2,
    )
def get_output_len(conv_module, input_len, pad_len, dim=0):
    """Return the output length of ``conv_module`` along hyper-parameter index ``dim``.

    Uses the standard convolution length formula
    ``(input_len + pad_len - effective_kernel) // stride + 1`` where
    ``effective_kernel = dilation * (kernel_size - 1) + 1``.

    Args:
        conv_module: object exposing indexable ``dilation``, ``kernel_size``
            and ``stride`` attributes.
        input_len: input length along the dimension.
        pad_len: total padding added along the dimension.
        dim: index into the module's per-dimension hyper-parameters.
    """
    effective_kernel = 1 + conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1)
    stride = conv_module.stride[dim]
    span = input_len + pad_len - effective_kernel
    return span // stride + 1
def get_cache_size(conv_module, input_len, pad_len, dim=0):
    """Return how many trailing input frames must be carried over to the next chunk.

    With ``output_len`` windows produced from ``input_len + pad_len`` frames,
    the next window would start at ``output_len * stride``; every frame from
    there to the end of the current input is carried over. This equals the
    window overlap (``effective_kernel - stride``) plus the frames left unused
    after the last window. It is only guaranteed to be non-negative when
    ``stride <= effective_kernel``.

    Args:
        conv_module: object exposing indexable ``dilation``, ``kernel_size``
            and ``stride`` attributes.
        input_len: input length along the dimension.
        pad_len: total padding added along the dimension.
        dim: index into the module's per-dimension hyper-parameters.

    Raises:
        AssertionError: if the input is too short to produce any output.
    """
    kernel = conv_module.kernel_size[dim]
    stride = conv_module.stride[dim]
    effective_kernel = conv_module.dilation[dim] * (kernel - 1) + 1
    padded_len = input_len + pad_len
    output_len = (padded_len - effective_kernel) // stride + 1
    cache_len = padded_len - output_len * stride
    logger.debug(
        f"I:{input_len}, "
        f"P:{pad_len}, "
        f"K:{kernel}, "
        f"S:{stride}, "
        f"O:{output_len}, "
        f"Cache:{cache_len}"
    )
    assert output_len > 0
    return cache_len
def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None):
    """Exchange causal-conv cache frames between neighbouring sequence-parallel ranks.

    When a sequence-parallel group exists and ``cache_size != 0``:

    * every rank except the last sends the final ``cache_size`` frames (dim 2)
      of ``tensor[-1]`` to the next rank;
    * every rank except the first receives ``cache_size`` frames from the
      previous rank into a fresh buffer.

    Rank 0 (or any rank when there is no group) instead builds its buffer from
    ``memory`` if given, or else by repeating the first frame of ``tensor[0]``
    ``times`` times (no buffer when ``times <= 0``).

    If ``tensor`` holds a single tensor that is shorter than ``cache_size``, the
    incoming buffer is first prepended to ``tensor[0]`` (mutating the list in
    place) so enough frames are available to send. ``None`` is then returned
    instead of the buffer.

    Args:
        tensor: list of local tensors; frames are along dim 2.
        cache_size: number of frames to exchange (0 disables communication).
        times: number of first-frame repeats used on rank 0 without ``memory``.
        memory: optional cached frames used as rank 0's buffer.

    Returns:
        The buffer of preceding frames for this rank, or ``None``.
    """
    sp_group = get_sequence_parallel_group()
    sp_rank = get_sequence_parallel_rank()
    sp_size = get_sequence_parallel_world_size()
    send_dst = get_next_sequence_parallel_rank()
    recv_src = get_prev_sequence_parallel_rank()
    recv_buffer = None
    recv_req = None
    send_req = None
    logger.debug(
        f"[sp{sp_rank}] cur_tensors:{[(t.size(), t.dtype) for t in tensor]}, times: {times}"
    )
    if sp_rank == 0 or sp_group is None:
        if memory is not None:
            recv_buffer = memory.to(tensor[0])
        elif times > 0:
            tile_repeat = [1] * tensor[0].ndim
            tile_repeat[2] = times
            recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
    if cache_size != 0 and sp_group is not None:
        if sp_rank > 0:
            shape = list(tensor[0].size())
            shape[2] = cache_size
            recv_buffer = torch.empty(
                *shape, device=tensor[0].device, dtype=tensor[0].dtype
            ).contiguous()
            # Post the receive before any blocking wait so the send chain from
            # rank 0 onward cannot deadlock.
            recv_req = dist.irecv(recv_buffer, recv_src, group=sp_group)
        if sp_rank < sp_size - 1:
            if cache_size > tensor[-1].size(2) and len(tensor) == 1:
                logger.debug(f"[sp{sp_rank}] force concat before send {tensor[-1].size()}")
                if recv_req is not None:
                    recv_req.wait()
                    recv_req = None  # already complete; no second wait below
                tensor[0] = torch.cat([recv_buffer, tensor[0]], dim=2)
                recv_buffer = None
            assert cache_size <= tensor[-1].size(
                2
            ), f"Not enough value to cache, got {tensor[-1].size()}, cache_size={cache_size}"
            send_req = dist.isend(
                tensor[-1][:, :, -cache_size:].detach().contiguous(), send_dst, group=sp_group
            )
    if recv_req is not None:
        recv_req.wait()
    # The send buffer may alias the caller's tensor; it must not be modified
    # before the isend completes, so wait on it before handing control back.
    if send_req is not None:
        send_req.wait()
    logger.debug(
        f"[sp{sp_rank}] recv_src:{recv_src}, "
        f"recv_buffer:{recv_buffer.size() if recv_buffer is not None else None}"
    )
    return recv_buffer