import torch
from diffusers.models.attention_processor import Attention
from .attention import WanSparseAttnProcessor2_0 
from .sparse_transformer import replace_sparse_forward
from ...attn_mask import MaskMap

def replace_wan_attention(
    pipe,
    height,
    width,
    num_frames,
    dense_layers=0,
    dense_timesteps=0,
    decay_factor=1.0,
    sparsity_type="radial",
):
    """Swap the self-attention processors in a Wan transformer for the sparse variants.

    `height`, `width`, and `num_frames` are pixel-space generation settings; they are
    converted below into latent-token counts used to build the attention mask map.
    """
    # Latent frame count: pixel frames compressed by the temporal VAE stride and temporal patch size.
    num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * pipe.transformer.config.patch_size[0])
    # Tokens per latent frame: spatial dims divided by the spatial VAE stride and spatial patch size.
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    frame_size = int(height // mod_value) * int(width // mod_value)

    # Configure the sparse processor class: dense block/timestep thresholds, the mask map
    # over all video tokens, the decay factor, and the sparsity pattern.
    AttnModule = WanSparseAttnProcessor2_0
    AttnModule.dense_block = dense_layers
    AttnModule.dense_timestep = dense_timesteps
    AttnModule.mask_map = MaskMap(video_token_num=frame_size * num_frames, num_frame=num_frames)
    AttnModule.decay_factor = decay_factor
    AttnModule.sparsity_type = sparsity_type

    print(f"Replacing Wan attention with {sparsity_type} attention")
    print(f"video token num: {AttnModule.mask_map.video_token_num}, num frames: {num_frames}")
    print(f"dense layers: {dense_layers}, dense timesteps: {dense_timesteps}, decay factor: {decay_factor}")

    # Patch the transformer forward pass so it supports the sparse attention processors.
    replace_sparse_forward()

    # Tag each block's self-attention (attn1) processor with its layer index ...
    for layer_idx, m in enumerate(pipe.transformer.blocks):
        m.attn1.processor.layer_idx = layer_idx

    # ... then swap every tagged Attention module for a sparse processor that knows its
    # layer index; cross-attention processors carry no layer_idx and are left untouched.
    for _, m in pipe.transformer.named_modules():
        if isinstance(m, Attention) and hasattr(m.processor, "layer_idx"):
            layer_idx = m.processor.layer_idx
            m.set_processor(AttnModule(layer_idx))
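

# Usage sketch (illustrative only): how this helper might be wired into a Wan text-to-video
# run with the diffusers WanPipeline. The checkpoint id and generation settings below are
# assumptions for illustration, not values taken from this repository.
#
#   import torch
#   from diffusers import WanPipeline
#
#   pipe = WanPipeline.from_pretrained(
#       "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   replace_wan_attention(
#       pipe,
#       height=480,
#       width=832,
#       num_frames=81,
#       dense_layers=1,
#       dense_timesteps=10,
#       decay_factor=1.0,
#       sparsity_type="radial",
#   )
#   video = pipe("a cat surfing a wave", height=480, width=832, num_frames=81).frames[0]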