# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719

import copy
from typing import Any

import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module

from .all_to_all import SeqAllToAll4D, SeqAllToAll5D
from .globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
from .ring import (
    ring_flash_attn_func,
    ring_flash_attn_qkvpacked_func,
    ring_flash_attn_varlen_func,
    ring_flash_attn_varlen_qkvpacked_func,
    stripe_flash_attn_func,
    stripe_flash_attn_qkvpacked_func,
    zigzag_ring_flash_attn_func,
    zigzag_ring_flash_attn_qkvpacked_func,
    zigzag_ring_flash_attn_varlen_func,
    zigzag_ring_flash_attn_varlen_qkvpacked_func,
)

RING_IMPL_DICT = {
    "ring": ring_flash_attn_func,
    "zigzag": zigzag_ring_flash_attn_func,
    "strip": stripe_flash_attn_func,
    "ring_varlen": ring_flash_attn_varlen_func,
    "zigzag_ring_varlen": zigzag_ring_flash_attn_varlen_func,
}

RING_IMPL_QKVPACKED_DICT = {
    "ring": ring_flash_attn_qkvpacked_func,
    "zigzag": zigzag_ring_flash_attn_qkvpacked_func,
    "strip": stripe_flash_attn_qkvpacked_func,
    "ring_varlen": ring_flash_attn_varlen_qkvpacked_func,
    "zigzag_varlen": zigzag_ring_flash_attn_varlen_qkvpacked_func,
}
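
# Illustrative note (not part of the upstream API): HybridAttention below looks up
# its ring kernel by the ring type string configured in the process-group manager,
# e.g.
#
#     ring_attn_fn = RING_IMPL_DICT["zigzag"]  # -> zigzag_ring_flash_attn_func
#
# The "strip" key keeps the upstream spelling and maps to the stripe_* kernels.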

class HybridAttention(torch.nn.Module):
    """Hybrid sequence-parallel attention combining a Ulysses all-to-all with ring attention.

    The Ulysses and ring process groups are taken from the global sequence-parallel
    process-group manager (set_pg_manager() must have been called beforehand).

    Arguments:
        scatter_idx (int): scatter_idx for the all2all comm
        gather_idx (int): gather_idx for the all2all comm
        use_pack_qkv (bool): pack q/k/v into a single all2all (not supported yet)
        attention_warper (Module): optional attention callable that replaces the
            ring implementation selected via get_ring_type()
    """

    def __init__(
        self,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        use_pack_qkv: bool = False,
        attention_warper: Module = None,
    ) -> None:
        super().__init__()
        self.ring_pg = get_ring_sp_pg()
        self.ulysses_pg = get_ulysses_sp_pg()
        self.use_pack_qkv = use_pack_qkv
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"Call set_pg_manager() first. Now ulysses pg is {self.ulysses_pg} and ring pg is {self.ring_pg}"
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        if attention_warper is None:
            self.ring_attn_fn = RING_IMPL_DICT[get_ring_type()]
        else:
            self.ring_attn_fn = attention_warper

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args: Any,
        attention_mask=None,
        dropout_p=0.0,
        softmax_scale=None,
        seqlens_in_batch=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
    ) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
        # scatter 2, gather 1
        if self.use_pack_qkv:
            # TODO (Qinghao): To support packed qkv
            raise NotImplementedError("Packed qkv is not supported yet.")
            # (3*bs, seq_len/N, head_cnt, head_size)
            qkv = torch.cat([query, key, value]).contiguous()
            # (3*bs, seq_len, head_cnt/N, head_size)
            qkv = SeqAllToAll4D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
            qkv = torch.chunk(qkv, 3, dim=0)
            out = self.ring_attn_fn(
                qkv[0],
                qkv[1],
                qkv[2],
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                deterministic=deterministic,
                return_attn_probs=return_attn_probs,
                group=self.ring_pg,
            )
        query_layer = SeqAllToAll4D.apply(self.ulysses_pg, query, self.scatter_idx, self.gather_idx)
        key_layer = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx)
        value_layer = SeqAllToAll4D.apply(self.ulysses_pg, value, self.scatter_idx, self.gather_idx)
        if attention_mask is not None:
            new_attention_mask = torch.cat([attention_mask] * dist.get_world_size(self.ulysses_pg), dim=1)
            out = self.ring_attn_fn(
                query_layer,
                key_layer,
                value_layer,
                *args,
                attention_mask=new_attention_mask,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                seqlens_in_batch=seqlens_in_batch,
                causal=causal,
                group=self.ring_pg,
            )
        else:
            out = self.ring_attn_fn(
                query_layer,
                key_layer,
                value_layer,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                deterministic=deterministic,
                return_attn_probs=return_attn_probs,
                group=self.ring_pg,
            )
        if isinstance(out, tuple):
            context_layer, _, _ = out
        else:
            context_layer = out
        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx)
        # out e.g., [s/p::h]
        return output
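
# A minimal usage sketch for HybridAttention (illustrative only; assumes
# torch.distributed and the sequence-parallel process groups were already set up
# via set_pg_manager(), whose exact signature is not shown in this file). Each rank
# holds its local sequence shard of shape (bs, seq_len / sp_world_size, head_cnt, head_size):
#
#     attn = HybridAttention()  # picks the ring kernel from get_ring_type()
#     q = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     k = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     v = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     out = attn(q, k, v, causal=True)  # out: (1, 1024, 32, 128), still sequence-sharded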

# TODO (Qinghao): To be supported
class HybridAttentionQKVPacked(torch.nn.Module):
    """Hybrid sequence-parallel attention (Ulysses all-to-all + ring attention) for packed qkv input.

    The Ulysses and ring process groups are taken from the global sequence-parallel
    process-group manager (set_pg_manager() must have been called beforehand).

    Arguments:
        scatter_idx (int): scatter_idx for the all2all comm
        gather_idx (int): gather_idx for the all2all comm
        ring_impl_type (str): key into RING_IMPL_QKVPACKED_DICT selecting the ring kernel
    """

    def __init__(
        self,
        scatter_idx: int = 3,
        gather_idx: int = 1,
        ring_impl_type: str = "zigzag",
    ) -> None:
        super().__init__()
        self.ring_pg = get_ring_sp_pg()
        self.ulysses_pg = get_ulysses_sp_pg()
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"Call set_pg_manager() first. Now ulysses pg is {self.ulysses_pg} and ring pg is {self.ring_pg}"
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]

    def forward(
        self,
        qkv,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        *args: Any,
    ) -> Tensor:
        """forward

        Arguments:
            qkv (Tensor): packed qkv input to the layer, (bs, seq_len/N, 3, head_cnt, head_size)
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # scatter 3, gather 1
        world_size = dist.get_world_size(self.ulysses_pg)
        if world_size > 1 and dist.is_initialized():
            qkv = SeqAllToAll5D.apply(self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx)
        out = self.ring_attn_fn(
            qkv,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=return_attn_probs,
            group=self.ring_pg,
        )
        # print(f"out {out.shape}")
        if isinstance(out, tuple):
            out = out[0]
        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        if world_size > 1 and dist.is_initialized():
            out = SeqAllToAll4D.apply(self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1)
        # out e.g., [s/p::h]
        return out
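
# A minimal usage sketch for the packed variant (illustrative only; same
# process-group assumptions as the sketch above). The packed layout stacks q/k/v
# on dim 2, matching scatter_idx=3 / gather_idx=1 of SeqAllToAll5D:
#
#     attn = HybridAttentionQKVPacked(ring_impl_type="zigzag")
#     q = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     k = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     v = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")
#     qkv = torch.stack([q, k, v], dim=2)  # (1, 1024, 3, 32, 128)
#     out = attn(qkv, causal=True)         # (1, 1024, 32, 128), still sequence-sharded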

# TODO (Qinghao): To be supported
class AsyncHybridAttention(torch.nn.Module):
    """Hybrid sequence-parallel attention that overlaps the Ulysses all-to-all with ring attention
    by processing the heads group by group on a side CUDA stream.

    The Ulysses and ring process groups are taken from the global sequence-parallel
    process-group manager (set_pg_manager() must have been called beforehand).

    Arguments:
        scatter_idx (int): scatter_idx for the all2all comm
        gather_idx (int): gather_idx for the all2all comm
        ring_impl_type (str): key into RING_IMPL_DICT selecting the ring kernel
    """

    def __init__(
        self,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        ring_impl_type: str = "zigzag",
    ) -> None:
        super().__init__()
        self.ring_pg = get_ring_sp_pg()
        self.ulysses_pg = get_ulysses_sp_pg()
        self.stream = torch.cuda.Stream()
        self._async_op = True
        assert (
            self.ulysses_pg is not None or self.ring_pg is not None
        ), f"Call set_pg_manager() first. Now ulysses pg is {self.ulysses_pg} and ring pg is {self.ring_pg}"
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type]

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        *args: Any,
    ) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer (bs, seqlen/P, hc, hs)
            key (Tensor): key input to the layer (bs, seqlen/P, hc_kv, hs)
            value (Tensor): value input to the layer (bs, seqlen/P, hc_kv, hs)
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # un * ud = hc
        ulysses_degree = dist.get_world_size(self.ulysses_pg)

        bs, shard_seqlen, hc, hs = query.shape
        bs, shard_seqlen, hc_kv, hs = key.shape
        seq_len = shard_seqlen * ulysses_degree
        un = hc // ulysses_degree
        un_kv = hc_kv // ulysses_degree
        assert un_kv == un, f"un_kv {un_kv} un {un}"

        qkv = torch.cat([query, key, value]).contiguous()
        # (3*bs, seqlen/P, hc, hs) -> (hc, seqlen/P, 3*bs, hs) -> (un, ud, seqlen/P, 3*bs, hs), where hc = un*ud
        qkv_list = torch.unbind(qkv.transpose(0, 2).contiguous().reshape(un, ulysses_degree, shard_seqlen, 3 * bs, hs))

        # 3x all-to-all output buffers, one per head group
        qkv_trans_list = [
            torch.zeros(
                ulysses_degree,
                1,
                shard_seqlen,
                3 * bs,
                hs,
                dtype=query.dtype,
                device=query.device,
            )
            for i in range(len(qkv_list))
        ]

        # last all-to-all buffers, one per head group
        context_layer_list = [
            torch.zeros(
                ulysses_degree,
                1,
                shard_seqlen,
                bs,
                hs,
                dtype=query.dtype,
                device=query.device,
            )
            for i in range(len(qkv_list))
        ]

        comm_handle_list = []
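        # Descriptive note: the first loop below launches one all-to-all per head
        # group on the side CUDA stream (async_op=True); the second loop waits on
        # group i only right before running ring attention on it, so communication
        # for later head groups overlaps with attention compute on earlier ones.
        # The reverse all-to-all is issued the same way and drained at the end.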
        # un * (ud, shard_seqlen, 3*bs, hs)
        for i, qkv in enumerate(qkv_list):
            with torch.cuda.stream(self.stream):
                ret = dist.all_to_all_single(
                    qkv_trans_list[i],
                    qkv,
                    group=self.ulysses_pg,
                    async_op=self._async_op,
                )
            comm_handle_list.append(ret)

        last_comm_handle_list = []
        for i, qkv_trans in enumerate(qkv_trans_list):
            if comm_handle_list[i] is not None:
                comm_handle_list[i].wait()
            qkv_trans = (
                qkv_trans.reshape(seq_len, 3 * bs, 1, hs).transpose(0, 1).contiguous().reshape(3 * bs, seq_len, 1, hs)
            )
            # qkv_trans = all_to_all_4D_async(qkv, qkv_trans_list[i], self.scatter_idx, self.gather_idx, self.ulysses_pg)
            qkv_trans = torch.chunk(qkv_trans, 3, dim=0)

            out = self.ring_attn_fn(
                qkv_trans[0],
                qkv_trans[1],
                qkv_trans[2],
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                deterministic=deterministic,
                return_attn_probs=return_attn_probs,
                group=self.ring_pg,
            )

            if isinstance(out, tuple):
                context_layer, _, _ = out
            else:
                context_layer = out

            # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
            # scatter 1, gather 2
            context_layer = (
                context_layer.reshape(bs, ulysses_degree, shard_seqlen, 1, hs)
                .transpose(0, 3)
                .transpose(0, 1)
                .contiguous()
                .reshape(ulysses_degree, 1, shard_seqlen, bs, hs)
            )
            with torch.cuda.stream(self.stream):
                ret = dist.all_to_all_single(
                    context_layer_list[i],
                    context_layer,
                    group=self.ulysses_pg,
                    async_op=self._async_op,
                )
            last_comm_handle_list.append(ret)

        # hc = un * P
        # un x (hc = P, seq_len/P, bs, hs) -> (bs, seq_len, hc = P, hs)
        for i, ret in enumerate(last_comm_handle_list):
            if ret is not None:
                ret.wait()
            context_layer_list[i] = (
                context_layer_list[i]
                .reshape(ulysses_degree, shard_seqlen, bs, hs)
                .transpose(0, 2)
                .contiguous()
                .reshape(bs, shard_seqlen, ulysses_degree, hs)
            )
        output = torch.cat(context_layer_list, dim=2)
        return output

    def backward(self, *args, **kwargs):
        raise RuntimeError("Backward computation is not allowed for AsyncHybridAttention.")
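
# A standalone sketch of the comm/compute overlap pattern used above (illustrative
# only; assumes torch.distributed with the NCCL backend is initialized and a CUDA
# device is set; `chunks` and `attention_on` are hypothetical placeholders). Each
# chunk's all-to-all is issued asynchronously on a side stream and waited on only
# right before that chunk is consumed:
#
#     side_stream = torch.cuda.Stream()
#     handles, outputs = [], []
#     for chunk in chunks:                       # e.g., per-head-group tensors
#         out_buf = torch.empty_like(chunk)
#         with torch.cuda.stream(side_stream):
#             h = dist.all_to_all_single(out_buf, chunk, async_op=True)
#         handles.append(h)
#         outputs.append(out_buf)
#     results = []
#     for h, out_buf in zip(handles, outputs):
#         h.wait()                               # comm for later chunks keeps running
#         results.append(attention_on(out_buf))  # hypothetical per-chunk compute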