# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
# This file is also partly modified from https://github.com/microsoft/DeepSpeed
# Implementation refers to Ulysses Paper: https://arxiv.org/abs/2309.14509

import copy
from typing import Any, Tuple

import deepspeed.comm as dist
import torch
import torch.distributed as torch_dist
from flash_attn import flash_attn_func
from torch import Tensor
from torch.nn import Module

from llava.train.sequence_parallel.globals import get_ulysses_seq_len, get_ulysses_sp_rank, get_ulysses_sp_size

from .all_to_all import SeqAllGather, SeqAllToAll4D, SeqAllToAll5D


class _ExpandKVFunction(torch.autograd.Function):
    """
    Copy each KV head `repeat_times` times to extend sequence parallel support for Ulysses.

    Args:
        k, v: input key/value tensors.
        repeat_times: the number of copies made of each head.
        num_head_dim: the index of the head dimension.
    """

    @staticmethod
    def forward(ctx, k, v, repeat_times, num_head_dim):
        kv_shape = k.shape
        num_heads_kv = kv_shape[num_head_dim]

        ctx.num_head_dim = num_head_dim
        ctx.num_heads_kv = num_heads_kv

        # Construct a repeat index indicating which dim should be copied.
        repeat_index = [1] * k.ndim
        repeat_index[num_head_dim] = repeat_times

        # Split the kv into one chunk per head.
        k_splits = torch.chunk(k, chunks=num_heads_kv, dim=num_head_dim)
        v_splits = torch.chunk(v, chunks=num_heads_kv, dim=num_head_dim)
        k_repeats, v_repeats = [], []

        # Copy each split repeat_times times.
        for split in k_splits:
            k_split_repeat = split.repeat(repeat_index)
            k_repeats.append(k_split_repeat)

        for split in v_splits:
            v_split_repeat = split.repeat(repeat_index)
            v_repeats.append(v_split_repeat)

        return torch.cat(k_repeats, dim=num_head_dim), torch.cat(v_repeats, dim=num_head_dim)

    @staticmethod
    def backward(ctx, grad_output_k, grad_output_v):
        """
        For backward, we sum the gradients of the copied heads inside each query group.
        """
        num_head_dim = ctx.num_head_dim
        num_heads_kv = ctx.num_heads_kv

        # Split the grads into one chunk per query group.
        grad_output_k_splits = torch.chunk(grad_output_k, chunks=num_heads_kv, dim=num_head_dim)
        grad_output_v_splits = torch.chunk(grad_output_v, chunks=num_heads_kv, dim=num_head_dim)

        grad_output_k_sums, grad_output_v_sums = [], []

        # Sum over the copied heads inside each split.
        for grad_output_k_split in grad_output_k_splits:
            grad_output_k_sum = grad_output_k_split.sum(dim=num_head_dim, keepdim=True)
            grad_output_k_sums.append(grad_output_k_sum)

        for grad_output_v_split in grad_output_v_splits:
            grad_output_v_sum = grad_output_v_split.sum(dim=num_head_dim, keepdim=True)
            grad_output_v_sums.append(grad_output_v_sum)

        # Concatenate the per-group sums along the num_head_dim dimension.
        grad_k = torch.cat(grad_output_k_sums, dim=num_head_dim)
        grad_v = torch.cat(grad_output_v_sums, dim=num_head_dim)

        return grad_k, grad_v, None, None


expandKV = _ExpandKVFunction.apply
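# Illustrative example (shapes here are assumptions, not fixed by this module):
# with k, v of shape (bs, shard_len, 4, head_size) and ulysses_degree = 8,
# expandKV(k, v, 2, 2) returns k, v of shape (bs, shard_len, 8, head_size),
# so the head-scattering all-to-all below can hand at least one (replicated)
# KV head to every sequence parallel rank.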


class UlyssesAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.ulysses_degree = get_ulysses_sp_size()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args: Any,
        attention_mask=None,
        dropout_p=0.0,
        softmax_scale=None,
        seqlens_in_batch=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
    ) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

        # KV replication for GQA
        head_dim = 2
        num_head_kv = key.shape[head_dim]
        if self.ulysses_degree > num_head_kv:
            assert (
                self.ulysses_degree % num_head_kv == 0
            ), "Ulysses requires the SP degree to be divisible by the number of KV heads."
            key, value = expandKV(key, value, self.ulysses_degree // num_head_kv, head_dim)

        # scatter 2, gather 1
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
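        # Illustrative shapes (assumed numbers, e.g. bs=1, 32 query heads, sp_size=4,
        # 2048-token shard per rank): query arrives as (1, 2048, 32, head_size); after
        # SeqAllToAll4D each rank holds q of shape (1, 8192, 8, head_size), i.e. the
        # full gathered sequence for 8 of the 32 heads.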

        if attention_mask is not None:
            local_attention_mask = copy.deepcopy(attention_mask)
            shard_seqlen = local_attention_mask.size(1)
            ulysses_seq_len = get_ulysses_seq_len()
            max_global_length = max(ulysses_seq_len)
            global_attention_mask_list = []
            for i in range(get_ulysses_sp_size()):
                if i == get_ulysses_sp_rank():
                    # Pad this rank's mask shard up to the longest shard length.
                    global_attention_mask_list.append(
                        torch.cat(
                            [
                                local_attention_mask,
                                torch.zeros(
                                    (local_attention_mask.size(0), max_global_length - shard_seqlen),
                                    dtype=local_attention_mask.dtype,
                                    device=local_attention_mask.device,
                                ),
                            ],
                            dim=1,
                        )
                    )
                else:
                    # Placeholder for the other ranks' shards; filled in by the all-reduce below.
                    global_attention_mask_list.append(
                        torch.zeros(
                            (local_attention_mask.size(0), max_global_length),
                            dtype=local_attention_mask.dtype,
                            device=local_attention_mask.device,
                        )
                    )

            global_attention_mask = torch.stack(global_attention_mask_list, dim=0)
            torch_dist.all_reduce(global_attention_mask, group=self.spg)
            torch_dist.barrier(group=self.spg)
            new_global_attention_mask_list = list(torch.unbind(global_attention_mask, dim=0))
            # Unpad the per-rank masks and concatenate the shards.
            for i in range(len(new_global_attention_mask_list)):
                new_global_attention_mask_list[i] = new_global_attention_mask_list[i][:, : ulysses_seq_len[i]]
            global_attention_mask = torch.cat(new_global_attention_mask_list, dim=1)
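            # At this point global_attention_mask covers the full gathered sequence,
            # (bs, sum(ulysses_seq_len)), matching q/k/v after the all-to-all above.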

            context_layer = self.local_attn(
                q,
                k,
                v,
                *args,
                attention_mask=global_attention_mask,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                seqlens_in_batch=seqlens_in_batch,
                causal=causal,
            )
        else:
            context_layer = self.local_attn(
                q,
                k,
                v,
                *args,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
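

# Usage sketch (illustrative only; the process group, tensors, and shapes below
# are assumptions, not part of this module):
#
#   dist_attn = UlyssesAttention(flash_attn_func, sequence_process_group=sp_group)
#   # each rank holds its own sequence shard: (bs, seq_len // sp_size, num_heads, head_size)
#   out = dist_attn(q, k, v, causal=True)
#   # out keeps the sharded layout: (bs, seq_len // sp_size, num_heads, head_size)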