# SeedVR2-3B: models/video_vae_v3/modules/context_parallel_lib.py
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import List
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from common.distributed import get_device
from common.distributed.advanced import (
get_next_sequence_parallel_rank,
get_prev_sequence_parallel_rank,
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
)
from common.distributed.ops import Gather
from common.logger import get_logger
from models.video_vae_v3.modules.types import MemoryState

logger = get_logger(__name__)


def causal_conv_slice_inputs(x, split_size, memory_state):
    """Slice the temporal dimension (dim=2) of x across sequence-parallel ranks."""
    sp_size = get_sequence_parallel_world_size()
    sp_group = get_sequence_parallel_group()
    sp_rank = get_sequence_parallel_rank()
    if sp_group is None:
        return x
    assert memory_state != MemoryState.UNSET
    # Rank 0 keeps one extra leading frame unless causal memory is already active.
    leave_out = 1 if memory_state != MemoryState.ACTIVE else 0
# Should have at least sp_size slices.
num_slices = (x.size(2) - leave_out) // split_size
assert num_slices >= sp_size, f"{num_slices} < {sp_size}"
split_sizes = [split_size + leave_out] + [split_size] * (num_slices - 1)
split_sizes += [x.size(2) - sum(split_sizes)]
assert sum(split_sizes) == x.size(2)
split_sizes = torch.tensor(split_sizes)
slices_per_rank = len(split_sizes) // sp_size
split_sizes = split_sizes.split(
[slices_per_rank] * (sp_size - 1) + [len(split_sizes) - slices_per_rank * (sp_size - 1)]
)
split_sizes = list(map(lambda s: s.sum().item(), split_sizes))
logger.debug(f"split_sizes: {split_sizes}")
return x.split(split_sizes, dim=2)[sp_rank]
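
# Hedged worked example (illustrative numbers, not from the original file):
# with sp_size=2, x.size(2)=19, split_size=4 and memory inactive (leave_out=1):
#   num_slices  = (19 - 1) // 4 = 4
#   split_sizes = [5, 4, 4, 4] + [19 - 17] = [5, 4, 4, 4, 2]
#   slices_per_rank = 5 // 2 = 2, so rank 0 keeps 5 + 4 = 9 frames and
#   rank 1 keeps the remaining 4 + 4 + 2 = 10 frames.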


def causal_conv_gather_outputs(x):
    """Gather per-rank temporal outputs (dim=2) back into a full sequence."""
sp_group = get_sequence_parallel_group()
sp_size = get_sequence_parallel_world_size()
if sp_group is None:
return x
# Communicate shapes.
unpad_lens = torch.empty((sp_size,), device=get_device(), dtype=torch.long)
local_unpad_len = torch.tensor([x.size(2)], device=get_device(), dtype=torch.long)
    dist.all_gather_into_tensor(unpad_lens, local_unpad_len, group=sp_group)
# Padding to max_len for gather.
    max_len = unpad_lens.max().item()
    x_pad = F.pad(x, (0, 0, 0, 0, 0, max_len - x.size(2))).contiguous()
# Gather outputs.
x_pad = Gather.apply(sp_group, x_pad, 2, True)
# Remove padding.
    x_chunks = list(x_pad.chunk(sp_size, dim=2))
    for i, (x_chunk, unpad_len) in enumerate(zip(x_chunks, unpad_lens)):
        x_chunks[i] = x_chunk[:, :, :unpad_len]
    return torch.cat(x_chunks, dim=2)
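
# Hedged usage sketch (assumes an initialized sequence-parallel group; `conv`
# and `split_size` are hypothetical stand-ins for the caller's modules/values):
#   x_local = causal_conv_slice_inputs(x, split_size, memory_state)
#   y_local = conv(x_local)
#   y = causal_conv_gather_outputs(y_local)  # temporal slices re-concatenated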


def get_output_len(conv_module, input_len, pad_len, dim=0):
    """Return the output length of conv_module along `dim` for a padded input."""
    dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
    output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
    return output_len
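
# Hedged worked example (illustrative numbers): with kernel_size=3, stride=2
# and dilation=1 on dim 0, the dilated kernel spans 1 * (3 - 1) + 1 = 3 frames,
# so input_len=9 with pad_len=1 yields (9 + 1 - 3) // 2 + 1 = 4 output frames.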


def get_cache_size(conv_module, input_len, pad_len, dim=0):
    """Return how many trailing input frames must be cached for the next chunk."""
    dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
    output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
    # Frames left over past the last full window, plus the kernel's stride overlap.
    remain_len = (
        input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
    )
    overlap_len = dilated_kernel_size - conv_module.stride[dim]
    cache_len = overlap_len + remain_len  # >= 0
logger.debug(
f"I:{input_len}, "
f"P:{pad_len}, "
f"K:{conv_module.kernel_size[dim]}, "
f"S:{conv_module.stride[dim]}, "
f"O:{output_len}, "
f"Cache:{cache_len}"
)
assert output_len > 0
return cache_len
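
# Hedged worked example, continuing the numbers above (K=3, S=2, D=1,
# input_len=9, pad_len=1): output_len = 4, the windows consume
# (4 - 1) * 2 + 3 = 9 of the 10 padded frames, so remain_len = 1;
# overlap_len = 3 - 2 = 1, hence cache_len = 2 trailing frames to keep.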


def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None):
    """Exchange temporal cache frames between adjacent sequence-parallel ranks.

    Each rank posts a receive for `cache_size` trailing frames from the previous
    rank and sends its own trailing frames to the next rank. Rank 0 has no
    predecessor, so it seeds the cache from `memory` or by repeating its first
    frame `times` times.
    """
sp_group = get_sequence_parallel_group()
sp_rank = get_sequence_parallel_rank()
sp_size = get_sequence_parallel_world_size()
send_dst = get_next_sequence_parallel_rank()
recv_src = get_prev_sequence_parallel_rank()
recv_buffer = None
recv_req = None
logger.debug(
f"[sp{sp_rank}] cur_tensors:{[(t.size(), t.dtype) for t in tensor]}, times: {times}"
)
    if sp_rank == 0 or sp_group is None:
        # No predecessor: build the cache from memory, or replicate the first
        # frame `times` times along the temporal dimension.
        if memory is not None:
            recv_buffer = memory.to(tensor[0])
        elif times > 0:
            tile_repeat = [1] * tensor[0].ndim
            tile_repeat[2] = times
            recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat)
    if cache_size != 0 and sp_group is not None:
        if sp_rank > 0:
            # Post a non-blocking receive for the cache from the previous rank.
            shape = list(tensor[0].size())
            shape[2] = cache_size
            recv_buffer = torch.empty(
                *shape, device=tensor[0].device, dtype=tensor[0].dtype
            ).contiguous()
            recv_req = dist.irecv(recv_buffer, recv_src, group=sp_group)
        if sp_rank < sp_size - 1:
            if cache_size > tensor[-1].size(2) and len(tensor) == 1:
                # Local chunk alone is too short: fold the received cache in
                # first so the outgoing slice has enough frames.
                logger.debug(f"[sp{sp_rank}] force concat before send {tensor[-1].size()}")
                if recv_req is not None:
                    recv_req.wait()
                tensor[0] = torch.cat([recv_buffer, tensor[0]], dim=2)
                recv_buffer = None
            assert cache_size <= tensor[-1].size(2), (
                f"Not enough values to cache, got {tensor[-1].size()}, cache_size={cache_size}"
            )
            dist.isend(
                tensor[-1][:, :, -cache_size:].detach().contiguous(), send_dst, group=sp_group
            )
if recv_req is not None:
recv_req.wait()
logger.debug(
f"[sp{sp_rank}] recv_src:{recv_src}, "
f"recv_buffer:{recv_buffer.size() if recv_buffer is not None else None}"
)
return recv_buffer
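
# Hedged usage sketch (assumes initialized distributed state; `chunk` and
# `pad_t` are hypothetical names for a local temporal chunk and the causal
# temporal padding used by the surrounding VAE code):
#   cache = cache_send_recv([chunk], cache_size, times=pad_t)
#   if cache is not None:
#       chunk = torch.cat([cache, chunk], dim=2)  # prepend the neighbor's tail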