# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
from typing import Any
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module
from .all_to_all import SeqAllToAll4D, SeqAllToAll5D
from .globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
from .ring import (
ring_flash_attn_func,
ring_flash_attn_qkvpacked_func,
ring_flash_attn_varlen_func,
ring_flash_attn_varlen_qkvpacked_func,
stripe_flash_attn_func,
stripe_flash_attn_qkvpacked_func,
zigzag_ring_flash_attn_func,
zigzag_ring_flash_attn_qkvpacked_func,
zigzag_ring_flash_attn_varlen_func,
zigzag_ring_flash_attn_varlen_qkvpacked_func,
)
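
# Hybrid sequence parallelism (USP): an Ulysses all-to-all redistributes attention heads within the
# ulysses group, so each rank holds the group-gathered sequence for a shard of heads, while ring
# flash-attention handles the remaining sequence shards held across the ring group.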
RING_IMPL_DICT = {
"ring": ring_flash_attn_func,
"zigzag": zigzag_ring_flash_attn_func,
"strip": stripe_flash_attn_func,
"ring_varlen": ring_flash_attn_varlen_func,
"zigzag_ring_varlen": zigzag_ring_flash_attn_varlen_func,
}
RING_IMPL_QKVPACKED_DICT = {
"ring": ring_flash_attn_qkvpacked_func,
"zigzag": zigzag_ring_flash_attn_qkvpacked_func,
"strip": stripe_flash_attn_qkvpacked_func,
"ring_varlen": ring_flash_attn_varlen_qkvpacked_func,
"zigzag_varlen": zigzag_ring_flash_attn_varlen_qkvpacked_func,
}
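
# RING_IMPL_DICT maps the ring type string returned by get_ring_type() to a ring flash-attention
# implementation; RING_IMPL_QKVPACKED_DICT does the same for the packed-qkv variants.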
class HybridAttention(torch.nn.Module):
"""Initialization.
Arguments:
ulysses_pg (ProcessGroup): ulysses process group
ring_pg (ProcessGroup): ring process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
scatter_idx: int = 2,
gather_idx: int = 1,
use_pack_qkv: bool = False,
attention_warper: Module = None,
) -> None:
super().__init__()
self.ring_pg = get_ring_sp_pg()
self.ulysses_pg = get_ulysses_sp_pg()
self.use_pack_qkv = use_pack_qkv
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"call set_pg_manager() first; got ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
if attention_warper is None:
self.ring_attn_fn = RING_IMPL_DICT[get_ring_type()]
else:
self.ring_attn_fn = attention_warper
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
*args: Any,
attention_mask=None,
dropout_p=0.0,
softmax_scale=None,
seqlens_in_batch=None,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
) -> Tensor:
"""forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
Returns:
* output (Tensor): context output
"""
# 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
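        # e.g. with an ulysses degree of 4: (2, 1024, 32, 128) per rank -> (2, 4096, 8, 128) per rank
        # (illustrative sizes), trading the sequence shard for a head shard before ring attention runs.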
if self.use_pack_qkv:
# TODO (Qinghao): To support packed qkv
raise NotImplementedError("Packed qkv is not supported yet.")
# (3*bs, seq_len/N, head_cnt, head_size)
            qkv = torch.cat([query, key, value]).contiguous()
# (3*bs, seq_len, head_cnt/N, head_size)
qkv = SeqAllToAll4D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
qkv = torch.chunk(qkv, 3, dim=0)
out = self.ring_attn_fn(
qkv[0],
qkv[1],
qkv[2],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
group=self.ring_pg,
)
query_layer = SeqAllToAll4D.apply(self.ulysses_pg, query, self.scatter_idx, self.gather_idx)
key_layer = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx)
value_layer = SeqAllToAll4D.apply(self.ulysses_pg, value, self.scatter_idx, self.gather_idx)
if attention_mask is not None:
new_attention_mask = torch.cat([attention_mask] * dist.get_world_size(self.ulysses_pg), dim=1)
out = self.ring_attn_fn(
query_layer,
key_layer,
value_layer,
*args,
attention_mask=new_attention_mask,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
seqlens_in_batch=seqlens_in_batch,
causal=causal,
group=self.ring_pg,
)
else:
out = self.ring_attn_fn(
query_layer,
key_layer,
value_layer,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
group=self.ring_pg,
)
        if isinstance(out, tuple):
context_layer, _, _ = out
else:
context_layer = out
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
# scatter 1, gather 2
output = SeqAllToAll4D.apply(self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx)
# out e.g., [s/p::h]
return output
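

# Illustrative usage sketch (an assumption for documentation, not part of the original API): the
# sequence-parallel process groups must already have been registered via set_pg_manager() and
# torch.distributed must be initialized (e.g. with torchrun). The shapes and dtype below are
# made-up examples; flash-attn expects fp16/bf16 CUDA tensors.
def _example_hybrid_attention_usage() -> Tensor:
    # Local (per-rank) q/k/v in the (bs, seq_len/N, head_cnt, head_size) layout the module expects.
    bs, shard_seqlen, head_cnt, head_size = 2, 1024, 32, 128
    query = torch.randn(bs, shard_seqlen, head_cnt, head_size, dtype=torch.bfloat16, device="cuda")
    key = torch.randn_like(query)
    value = torch.randn_like(query)
    attn = HybridAttention(scatter_idx=2, gather_idx=1)
    # Output comes back in the same (bs, seq_len/N, head_cnt, head_size) layout as the inputs.
    return attn(query, key, value, causal=True)
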
# TODO (Qinghao): To be supported
class HybridAttentionQKVPacked(torch.nn.Module):
"""Initialization.
Arguments:
ulysses_pg (ProcessGroup): ulysses process group
ring_pg (ProcessGroup): ring process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
scatter_idx: int = 3,
gather_idx: int = 1,
ring_impl_type: str = "zigzag",
) -> None:
super().__init__()
self.ring_pg = get_ring_sp_pg()
self.ulysses_pg = get_ulysses_sp_pg()
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"call set_pg_manager() first; got ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]
def forward(
self,
qkv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*args: Any,
) -> Tensor:
"""forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
Returns:
* output (Tensor): context output
"""
# scatter 3, gather 1
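        # Presumably qkv is packed as (bs, seq_len/N, 3, head_cnt, head_size): with the default
        # scatter_idx=3 / gather_idx=1, the 5D all-to-all shards heads and gathers the sequence,
        # mirroring the unpacked path in HybridAttention.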
world_size = dist.get_world_size(self.ulysses_pg)
if world_size > 1 and dist.is_initialized():
qkv = SeqAllToAll5D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
out = self.ring_attn_fn(
qkv,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
group=self.ring_pg,
)
# print(f"out {out.shape}")
        if isinstance(out, tuple):
out = out[0]
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
# scatter 1, gather 2
if world_size > 1 and dist.is_initialized():
out = SeqAllToAll4D.apply(self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1)
# out e.g., [s/p::h]
return out
# TODO (Qinghao): To be supported
class AsyncHybridAttention(torch.nn.Module):
"""Initialization.
Arguments:
ulysses_pg (ProcessGroup): ulysses process group
ring_pg (ProcessGroup): ring process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
scatter_idx: int = 2,
gather_idx: int = 1,
ring_impl_type: str = "zigzag",
) -> None:
super().__init__()
self.ring_pg = get_ring_sp_pg()
self.ulysses_pg = get_ulysses_sp_pg()
self.stream = torch.cuda.Stream()
self._async_op = True
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"call set_pg_manager() first; got ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type]
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*args: Any,
) -> Tensor:
"""forward
Arguments:
query (Tensor): query input to the layer (bs, seqlen/P, hc, hs)
key (Tensor): key input to the layer (bs, seqlen/P, hc_kv, hs)
value (Tensor): value input to the layer (bs, seqlen/P, hc_kv, hs)
args: other args
Returns:
* output (Tensor): context output
"""
# un*ud = hc
ulysses_degree = dist.get_world_size(self.ulysses_pg)
bs, shard_seqlen, hc, hs = query.shape
bs, shard_seqlen, hc_kv, hs = key.shape
seq_len = shard_seqlen * ulysses_degree
un = hc // ulysses_degree
un_kv = hc_kv // ulysses_degree
assert un_kv == un, f"un_kv {un_kv} un {un}"
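        # Note: the kv head count must match the query head count per ulysses shard here, so
        # grouped-/multi-query layouts with fewer kv heads are not handled by this async path.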
qkv = torch.cat([query, key, value]).contiguous()
# (3*bs, seqlen/P, hc, hs) -> (hc, seqlen/P, 3*bs, hs) -> (un, ud, seqlen/P, 3*bs, hs), where hc = un*ud
qkv_list = torch.unbind(qkv.transpose(0, 2).contiguous().reshape(un, ulysses_degree, shard_seqlen, 3 * bs, hs))
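        # Each of the `un` chunks carries `ud` heads (one per ulysses rank); launching their
        # all-to-alls asynchronously on a side stream below lets communication overlap with the
        # per-chunk ring attention in the second loop.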
        # all-to-all output buffers for the packed (3*bs) qkv, one per head chunk
qkv_trans_list = [
torch.zeros(
ulysses_degree,
1,
shard_seqlen,
3 * bs,
hs,
dtype=query.dtype,
device=query.device,
)
for i in range(len(qkv_list))
]
        # last all-to-all buffer
context_layer_list = [
torch.zeros(
ulysses_degree,
1,
shard_seqlen,
bs,
hs,
dtype=query.dtype,
device=query.device,
)
for i in range(len(qkv_list))
]
comm_handle_list = []
# un * (ud, shard_seqlen, 3*bs, hs)
for i, qkv in enumerate(qkv_list):
with torch.cuda.stream(self.stream):
ret = dist.all_to_all_single(
qkv_trans_list[i],
qkv,
group=self.ulysses_pg,
async_op=self._async_op,
)
comm_handle_list.append(ret)
last_comm_handle_list = []
for i, qkv_trans in enumerate(qkv_trans_list):
if comm_handle_list[i] is not None:
comm_handle_list[i].wait()
qkv_trans = (
qkv_trans.reshape(seq_len, 3 * bs, 1, hs).transpose(0, 1).contiguous().reshape(3 * bs, seq_len, 1, hs)
)
# qkv_trans = all_to_all_4D_async(qkv, qkv_trans_list[i], self.scatter_idx, self.gather_idx, self.ulysses_pg)
qkv_trans = torch.chunk(qkv_trans, 3, dim=0)
out = self.ring_attn_fn(
qkv_trans[0],
qkv_trans[1],
qkv_trans[2],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
group=self.ring_pg,
)
            if isinstance(out, tuple):
context_layer, _, _ = out
else:
context_layer = out
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
# scatter 1, gather 2
context_layer = (
context_layer.reshape(bs, ulysses_degree, shard_seqlen, 1, hs)
.transpose(0, 3)
.transpose(0, 1)
.contiguous()
.reshape(ulysses_degree, 1, shard_seqlen, bs, hs)
)
with torch.cuda.stream(self.stream):
ret = dist.all_to_all_single(
context_layer_list[i],
context_layer,
group=self.ulysses_pg,
async_op=self._async_op,
)
last_comm_handle_list.append(ret)
        # hc = un * P
        # un x (P, seq_len/P, bs, hs) -> un x (bs, seq_len/P, P, hs) -> cat over heads -> (bs, seq_len/P, hc, hs)
for i, ret in enumerate(last_comm_handle_list):
if ret is not None:
ret.wait()
context_layer_list[i] = (
context_layer_list[i]
.reshape(ulysses_degree, shard_seqlen, bs, hs)
.transpose(0, 2)
.contiguous()
.reshape(bs, shard_seqlen, ulysses_degree, hs)
)
output = torch.cat(context_layer_list, dim=2)
return output
def backward(self, *args, **kwargs):
raise RuntimeError("Backward computation is not allowed for AsyncHybridAttention.")