# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
# This file is also partly modified from https://github.com/microsoft/DeepSpeed
# Implementation refers to Ulysses Paper: https://arxiv.org/abs/2309.14509
import copy
from typing import Any, Tuple
import deepspeed.comm as dist
import torch
import torch.distributed as torch_dist
from flash_attn import flash_attn_func
from torch import Tensor
from torch.nn import Module
from llava.train.sequence_parallel.globals import get_ulysses_seq_len, get_ulysses_sp_rank, get_ulysses_sp_size
from .all_to_all import SeqAllGather, SeqAllToAll4D, SeqAllToAll5D
class _ExpandKVFunction(torch.autograd.Function):
"""
Copy the KV head repeat times to extend sequence parallel support for Ulysses.
Args:
kv: input kv.
repeat_times: the repeat number of each head.
num_head_dim: the dimension of head number.
"""
@staticmethod
def forward(ctx, k, v, repeat_times, num_head_dim):
kv_shape = k.shape
num_heads_kv = kv_shape[num_head_dim]
ctx.num_head_dim = num_head_dim
ctx.num_heads_kv = num_heads_kv
        # Build a repeat spec that replicates only the head dimension.
repeat_index = [1] * k.ndim
repeat_index[num_head_dim] = repeat_times
        # Split k and v into one chunk per KV head.
k_splits = torch.chunk(k, chunks=num_heads_kv, dim=num_head_dim)
v_splits = torch.chunk(v, chunks=num_heads_kv, dim=num_head_dim)
k_repeats, v_repeats = [], []
        # Replicate each per-head chunk repeat_times times.
for split in k_splits:
k_split_repeat = split.repeat(repeat_index)
k_repeats.append(k_split_repeat)
for split in v_splits:
v_split_repeat = split.repeat(repeat_index)
v_repeats.append(v_split_repeat)
return torch.cat(k_repeats, dim=num_head_dim), torch.cat(v_repeats, dim=num_head_dim)
@staticmethod
def backward(ctx, grad_output_k, grad_output_v):
"""
For backward, we sum the copy head inside a query group.
"""
num_head_dim = ctx.num_head_dim
num_heads_kv = ctx.num_heads_kv
        # Split the gradients into one chunk per original KV head (query group).
grad_output_k_splits = torch.chunk(grad_output_k, chunks=num_heads_kv, dim=num_head_dim)
grad_output_v_splits = torch.chunk(grad_output_v, chunks=num_heads_kv, dim=num_head_dim)
grad_output_k_sums, grad_output_v_sums = [], []
        # Sum each group's gradient over its replicated heads.
for grad_output_k_split in grad_output_k_splits:
grad_output_k_sum = grad_output_k_split.sum(dim=num_head_dim, keepdim=True)
grad_output_k_sums.append(grad_output_k_sum)
for grad_output_v_split in grad_output_v_splits:
grad_output_v_sum = grad_output_v_split.sum(dim=num_head_dim, keepdim=True)
grad_output_v_sums.append(grad_output_v_sum)
        # Concatenate the per-group sums back along the head dimension.
grad_k = torch.cat(grad_output_k_sums, dim=num_head_dim)
grad_v = torch.cat(grad_output_v_sums, dim=num_head_dim)
return grad_k, grad_v, None, None
expandKV = _ExpandKVFunction.apply
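

# A minimal, hypothetical sketch (not part of the original code path) showing what
# expandKV does to shapes, assuming a (bs, seq_len, num_heads_kv, head_dim) layout
# with num_head_dim=2.
def _example_expand_kv_shapes():
    k = torch.randn(1, 8, 2, 16, requires_grad=True)
    v = torch.randn(1, 8, 2, 16, requires_grad=True)
    # Replicate each of the 2 KV heads 4 times -> 8 heads along dim 2.
    k_rep, v_rep = expandKV(k, v, 4, 2)
    assert k_rep.shape == (1, 8, 8, 16) and v_rep.shape == (1, 8, 8, 16)
    # The backward pass sums gradients over each group of 4 replicated heads,
    # so the gradients keep the original (1, 8, 2, 16) KV shape.
    (k_rep.sum() + v_rep.sum()).backward()
    assert k.grad.shape == (1, 8, 2, 16) and v.grad.shape == (1, 8, 2, 16)
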
class UlyssesAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup = None,
scatter_idx: int = 2,
gather_idx: int = 1,
) -> None:
super().__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.ulysses_degree = get_ulysses_sp_size()
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
*args: Any,
attention_mask=None,
dropout_p=0.0,
softmax_scale=None,
seqlens_in_batch=None,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
) -> Tensor:
"""forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
Returns:
* output (Tensor): context output
"""
        # KV replication for GQA: if the sequence-parallel degree exceeds the number of
        # KV heads, replicate KV heads so they can be scattered across all SP ranks.
head_dim = 2
num_head_kv = key.shape[head_dim]
if self.ulysses_degree > num_head_kv:
            assert (
                self.ulysses_degree % num_head_kv == 0
            ), "Ulysses requires the sequence-parallel degree to be divisible by the number of KV heads."
key, value = expandKV(key, value, self.ulysses_degree // num_head_kv, head_dim)
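            # After replication the KV head count equals the SP degree, so each rank
            # receives exactly one KV head from the all-to-all below.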
        # All-to-all: scatter heads (dim 2), gather sequence (dim 1)
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
if attention_mask is not None:
local_attention_mask = copy.deepcopy(attention_mask)
shard_seqlen = local_attention_mask.size(1)
ulysses_seq_len = get_ulysses_seq_len()
max_global_length = max(ulysses_seq_len)
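            # Rebuild the full-sequence mask: each rank places its local shard, zero-padded
            # to the longest shard, in its own slot; the all-reduce sums the slots; padding
            # is then trimmed and the shards are concatenated in rank order.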
global_attention_mask_list = []
for i in range(get_ulysses_sp_size()):
if i == get_ulysses_sp_rank():
global_attention_mask_list.append(
torch.cat(
[
local_attention_mask,
torch.zeros(
(local_attention_mask.size(0), max_global_length - shard_seqlen),
dtype=local_attention_mask.dtype,
device=local_attention_mask.device,
),
],
dim=1,
)
)
else:
global_attention_mask_list.append(
torch.zeros(
(local_attention_mask.size(0), max_global_length),
dtype=local_attention_mask.dtype,
device=local_attention_mask.device,
)
)
global_attention_mask = torch.stack(global_attention_mask_list, dim=0)
torch_dist.all_reduce(global_attention_mask, group=self.spg)
torch_dist.barrier(group=self.spg)
new_global_attention_mask_list = list(torch.unbind(global_attention_mask, dim=0))
            # Unpad each rank's shard to its true length and concatenate along the sequence dim
for i in range(len(new_global_attention_mask_list)):
new_global_attention_mask_list[i] = new_global_attention_mask_list[i][:, : ulysses_seq_len[i]]
global_attention_mask = torch.cat(new_global_attention_mask_list, dim=1)
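            # The reconstructed mask now spans the full global sequence, matching the
            # all-to-all-gathered q/k/v seen by the local attention below.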
context_layer = self.local_attn(
q,
k,
v,
*args,
attention_mask=global_attention_mask,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
seqlens_in_batch=seqlens_in_batch,
causal=causal,
)
else:
context_layer = self.local_attn(
q,
k,
v,
*args,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
)
if isinstance(context_layer, tuple):
context_layer = context_layer[0]
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
# scatter 1, gather 2
output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        # Output layout, e.g., [s/p :: h]: sequence re-sharded, all heads local.
return output
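

# A minimal usage sketch (hypothetical, not part of the original module). It assumes
# torch.distributed / DeepSpeed are initialized, the Ulysses globals above are configured,
# and flash_attn_func (imported at the top of this file) serves as the local attention.
# Each rank passes its local sequence shard of shape (bs, seq_len/N, head_cnt, head_size).
def _example_ulysses_attention(sp_group, query, key, value):
    attn = UlyssesAttention(
        local_attention=flash_attn_func,  # any callable taking (q, k, v, ...)
        sequence_process_group=sp_group,
        scatter_idx=2,  # scatter attention heads across SP ranks
        gather_idx=1,  # gather sequence shards into the full sequence
    )
    # Output comes back in the local-shard layout: (bs, seq_len/N, head_cnt, head_size).
    return attn(query, key, value, causal=True)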