SeedVR2-3B / models /dit /nablocks /mmsr_block.py
IceClear
upload files
42f2c22
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import Tuple, Union
import torch
from einops import rearrange
from torch.nn import functional as F
# from ..cache import Cache
from common.cache import Cache
from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv
from .. import na
from ..attention import FlashAttentionVarlen
from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock
from ..mm import MMArg
from ..modulation import ada_layer_type
from ..normalization import norm_layer_type
from ..rope import NaRotaryEmbedding3d
from ..window import get_window_op
class NaSwinAttention(MMWindowAttention):
    """Windowed joint video/text attention on packed (flattened) sequences.

    The video qkv sequence is regrouped into windows, every window is
    concatenated with text tokens, and attention runs through
    ``FlashAttentionVarlen`` with per-window cumulative sequence lengths.
    Index transforms and length tensors are memoised through the ``cache``
    object handed to :meth:`forward`.
    """

    def __init__(
        self,
        vid_dim: int,
        txt_dim: int,
        heads: int,
        head_dim: int,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: norm_layer_type,
        qk_norm_eps: float,
        window: Union[int, Tuple[int, int, int]],
        window_method: str,
        shared_qkv: bool,
        **kwargs,
    ):
        """Build the attention layer.

        All named arguments are forwarded unchanged to ``MMWindowAttention``.
        ``**kwargs`` is accepted so callers may pass a wider config, but it is
        not used here and not forwarded to the parent.

        Args:
            vid_dim: Forwarded to the parent as ``vid_dim``.
            txt_dim: Forwarded to the parent as ``txt_dim``.
            heads: Forwarded to the parent as ``heads``.
            head_dim: Per-head channel count; also used to size the rotary
                embedding (``head_dim // 2``) and to split the fused qkv.
            qk_bias: Forwarded to the parent as ``qk_bias``.
            qk_rope: When true, a ``NaRotaryEmbedding3d`` is created;
                otherwise ``self.rope`` is ``None``.
            qk_norm: Forwarded to the parent as ``qk_norm``.
            qk_norm_eps: Forwarded to the parent as ``qk_norm_eps``.
            window: Window specification handed to the window op.
            window_method: Name used to look up the window op via
                ``get_window_op``.
            shared_qkv: Forwarded to the parent as ``shared_qkv``.
            **kwargs: Ignored.
        """
        super().__init__(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_rope=qk_rope,
            qk_norm=qk_norm,
            qk_norm_eps=qk_norm_eps,
            window=window,
            window_method=window_method,
            shared_qkv=shared_qkv,
        )
        self.rope = NaRotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
        self.attn = FlashAttentionVarlen()
        self.window_op = get_window_op(window_method)

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run windowed attention over video tokens joined with text tokens.

        Args:
            vid: Packed video tokens (annotated ``l c``).
            txt: Packed text tokens (annotated ``l c``).
            vid_shape: Per-sample video shape (annotated ``b 3``).
            txt_shape: Per-sample text shape (annotated ``b 1``).
            cache: Callable cache supporting ``cache(key, factory)`` and
                ``cache.namespace(name)``; used to memoise index transforms
                and length tensors.

        Returns:
            ``(vid_out, txt_out)`` as produced by ``self.proj_out``.
        """
        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
        # NOTE(review): presumably switches from a sequence-sharded to a
        # head-sharded layout for sequence parallelism -- confirm against
        # common.distributed.ops.
        vid_qkv = gather_seq_scatter_heads_qkv(
            vid_qkv,
            seq_dim=0,
            qkv_shape=vid_shape,
            cache=cache.namespace("vid"),
        )
        txt_qkv = gather_seq_scatter_heads_qkv(
            txt_qkv,
            seq_dim=0,
            qkv_shape=txt_shape,
            cache=cache.namespace("txt"),
        )

        # re-org the input seq for window attn
        # Window-related cache entries live in a namespace keyed by the window
        # method and window spec.
        cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")

        def make_window(x: torch.Tensor):
            # Cut a (t, h, w, ...) grid into the sub-blocks described by the
            # slice triples returned from the window op.
            t, h, w, _ = x.shape
            window_slices = self.window_op((t, h, w), self.window)
            return [x[st, sh, sw] for (st, sh, sw) in window_slices]

        window_partition, window_reverse, window_shape, window_count = cache_win(
            "win_transform",
            lambda: na.window_idx(vid_shape, make_window),
        )
        vid_qkv_win = window_partition(vid_qkv)

        # Split the fused channel axis into (q/k/v, heads, head_dim).
        vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)

        vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
        txt_q, txt_k, txt_v = txt_qkv.unbind(1)

        vid_q, txt_q = self.norm_q(vid_q, txt_q)
        vid_k, txt_k = self.norm_k(vid_k, txt_k)

        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))

        # Per-window lengths: video tokens in the window plus the text length
        # of the sample the window belongs to (text length repeated once per
        # window of that sample).
        vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
        txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
        all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
        concat_win, unconcat_win = cache_win(
            "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count)
        )

        # window rope
        # Explicit None check: ``self.rope`` is either a module or None, and
        # module truthiness is not a reliable "is present" test.
        if self.rope is not None:
            vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

        def cu_seqlens():
            # Cumulative lengths with a leading zero, cast to int32.
            return F.pad(all_len_win.cumsum(0), (1, 0)).int()

        def max_seqlen():
            # Longest joint (video + text) window as a Python number.
            return all_len_win.max().item()

        # Inputs are cast to bfloat16 for the attention kernel and the result
        # is cast back to the dtype of ``vid_q``.
        out = self.attn(
            q=concat_win(vid_q, txt_q).bfloat16(),
            k=concat_win(vid_k, txt_k).bfloat16(),
            v=concat_win(vid_v, txt_v).bfloat16(),
            cu_seqlens_q=cache_win("vid_seqlens_q", cu_seqlens),
            cu_seqlens_k=cache_win("vid_seqlens_k", cu_seqlens),
            max_seqlen_q=cache_win("vid_max_seqlen_q", max_seqlen),
            max_seqlen_k=cache_win("vid_max_seqlen_k", max_seqlen),
        ).type_as(vid_q)

        # text pooling
        # NOTE(review): presumably the per-window text outputs are merged back
        # into one text sequence here -- confirm in na.repeat_concat_idx.
        vid_out, txt_out = unconcat_win(out)

        # Merge heads back into a single channel axis.
        vid_out = rearrange(vid_out, "l h d -> l (h d)")
        txt_out = rearrange(txt_out, "l h d -> l (h d)")
        vid_out = window_reverse(vid_out)

        vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0)
        txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0)

        vid_out, txt_out = self.proj_out(vid_out, txt_out)

        return vid_out, txt_out
class NaMMSRTransformerBlock(MMWindowTransformerBlock):
    """Multi-modal window transformer block that uses ``NaSwinAttention``.

    The forward pass applies, for both the video and text branches, a
    norm -> modulation -> attention -> modulation -> residual step followed
    by the same pattern around the MLP.
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: norm_layer_type,
        norm_eps: float,
        ada: ada_layer_type,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: norm_layer_type,
        shared_qkv: bool,
        shared_mlp: bool,
        mlp_type: str,
        **kwargs,
    ):
        """Build the block.

        Every named argument, plus ``**kwargs``, is forwarded to
        ``MMWindowTransformerBlock``. ``self.attn`` is then set to a
        ``NaSwinAttention`` built from the attention-related arguments and
        ``**kwargs``; ``norm_eps`` is reused as its ``qk_norm_eps``.
        """
        super().__init__(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            emb_dim=emb_dim,
            heads=heads,
            head_dim=head_dim,
            expand_ratio=expand_ratio,
            norm=norm,
            norm_eps=norm_eps,
            ada=ada,
            qk_bias=qk_bias,
            qk_rope=qk_rope,
            qk_norm=qk_norm,
            shared_qkv=shared_qkv,
            shared_mlp=shared_mlp,
            mlp_type=mlp_type,
            **kwargs,
        )

        # Assigned after the parent constructor so this attention module is the
        # one stored under ``self.attn``.
        attn_config = {
            "vid_dim": vid_dim,
            "txt_dim": txt_dim,
            "heads": heads,
            "head_dim": head_dim,
            "qk_bias": qk_bias,
            "qk_rope": qk_rope,
            "qk_norm": qk_norm,
            "qk_norm_eps": norm_eps,
            "shared_qkv": shared_qkv,
        }
        self.attn = NaSwinAttention(**attn_config, **kwargs)

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        """Apply the attention and MLP sub-layers to both branches.

        Args:
            vid: Packed video tokens (annotated ``l c``).
            txt: Packed text tokens (annotated ``l c``).
            vid_shape: Per-sample video shape (annotated ``b 3``).
            txt_shape: Per-sample text shape (annotated ``b 1``).
            emb: Conditioning embedding passed to the modulation layer.
            cache: Callable cache used for per-sample lengths and passed on
                to the modulation and attention layers.

        Returns:
            ``(vid, txt, vid_shape, txt_shape)``; the two shape tensors are
            returned exactly as received.
        """
        # Per-sample token counts, memoised under "vid_len" / "txt_len".
        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
        hid_len = MMArg(vid_len, txt_len)
        branch_tag = MMArg("vid", "txt")

        def modulate(vid_x, txt_x, layer, mode):
            # One modulation call; the conditioning arguments are identical
            # for all four call sites.
            return self.ada(
                vid_x,
                txt_x,
                layer=layer,
                mode=mode,
                emb=emb,
                hid_len=hid_len,
                cache=cache,
                branch_tag=branch_tag,
            )

        # Attention sub-layer with residual connection.
        vid_a, txt_a = self.attn_norm(vid, txt)
        vid_a, txt_a = modulate(vid_a, txt_a, "attn", "in")
        vid_a, txt_a = self.attn(vid_a, txt_a, vid_shape, txt_shape, cache)
        vid_a, txt_a = modulate(vid_a, txt_a, "attn", "out")
        vid_a = vid_a + vid
        txt_a = txt_a + txt

        # MLP sub-layer with residual connection.
        vid_m, txt_m = self.mlp_norm(vid_a, txt_a)
        vid_m, txt_m = modulate(vid_m, txt_m, "mlp", "in")
        vid_m, txt_m = self.mlp(vid_m, txt_m)
        vid_m, txt_m = modulate(vid_m, txt_m, "mlp", "out")
        vid_m = vid_m + vid_a
        txt_m = txt_m + txt_a

        return vid_m, txt_m, vid_shape, txt_shape