# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
from ..modules.attention import flash_attention
from .util import all_to_all
def distributed_attention(
    q,
    k,
    v,
    seq_lens,
    window_size=(-1, -1),
):
"""
Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
please refer to https://arxiv.org/pdf/2309.14509
Args:
q: [B, Lq // p, Nq, C1].
k: [B, Lk // p, Nk, C1].
v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
seq_lens: [B], length of each sequence in batch
window_size: (left right). If not (-1, -1), apply sliding window local attention.
"""
    if not dist.is_initialized():
        raise ValueError("distributed group should be initialized.")
    b = q.shape[0]

    # all-to-all #1: gather the full sequence on every rank, scatter the heads
    q = all_to_all(q, scatter_dim=2, gather_dim=1)
    k = all_to_all(k, scatter_dim=2, gather_dim=1)
    v = all_to_all(v, scatter_dim=2, gather_dim=1)

    # attention over the full sequence using this rank's subset of heads
    x = flash_attention(
        q,
        k,
        v,
        k_lens=seq_lens,
        window_size=window_size,
    )

    # all-to-all #2: gather the heads back, re-scatter the sequence across ranks
    x = all_to_all(x, scatter_dim=1, gather_dim=2)
    return x
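

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal driver script,
# assuming the function above is importable as `distributed_attention` (the
# import path below is a placeholder) and that the script is launched with
# torchrun so the default NCCL process group can be initialized. Shapes and
# dtypes are illustrative only; both the sequence length and the number of
# heads must be divisible by the world size for the Ulysses all-to-all.
#
#     torchrun --nproc_per_node=4 run_distributed_attention.py
# ---------------------------------------------------------------------------
import torch
import torch.distributed as dist

from wan.distributed.attention import distributed_attention  # placeholder path


def main():
    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)  # assumes one process per GPU on a single node

    b, seq_len, num_heads, head_dim = 1, 4096, 16, 64
    local_len = seq_len // world_size  # each rank holds Lq // p tokens (Ulysses layout)
    q = torch.randn(b, local_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    seq_lens = torch.tensor([seq_len], device="cuda")  # full, unsharded length per batch element

    out = distributed_attention(q, k, v, seq_lens)
    assert out.shape == q.shape  # output comes back sequence-sharded, like the inputs

    dist.destroy_process_group()


if __name__ == "__main__":
    main()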