# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Adopted from https://github.com/zhuzilin/ring-flash-attention.
# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
import torch
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward

from .utils import RingComm, update_out_and_lse
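
# Zigzag layout (context for the functions below): the full sequence is split into
# 2 * world_size chunks along the sequence dimension, and rank r keeps chunks r and
# 2 * world_size - 1 - r concatenated. Pairing a chunk from the front of the sequence
# with its mirror from the back gives every rank roughly the same amount of causal
# attention work; the `step <= comm.rank` / `else` branches in the ring loops below
# rely on exactly this layout.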


def zigzag_ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    assert causal, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)

    block_seq_len = q.shape[1] // 2
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None

    def forward(q, k, v, causal):
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=dropout_p > 0,
        )
        return block_out, block_lse

    for step in range(comm.world_size):
        # overlap the next k/v exchange with this step's computation
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            # local block: full causal attention over this rank's own k/v
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # incoming k/v come from an earlier rank in the zigzag order: all local
            # queries may attend, but only to the first half of the block (its earlier
            # chunk); the second half lies entirely in the future
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            # incoming k/v come from a later rank: only the second half of the local
            # queries may attend, and they see the whole incoming block
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    # cast the accumulated output back to the input dtype and return lse in
    # flash-attn's (batch, nheads, seqlen) layout
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse
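
# Illustrative reference (not part of the original module): `update_out_and_lse` is
# assumed to implement the standard log-sum-exp merge of partial attention outputs
# computed over disjoint key/value blocks,
#     lse = logaddexp(lse_a, lse_b)
#     out = exp(lse_a - lse) * out_a + exp(lse_b - lse) * out_b
# A shape-agnostic sketch of that merge (each lse_* must be broadcastable against the
# matching out_*):
def _reference_merge_two_blocks(out_a, lse_a, out_b, lse_b):
    lse = torch.logaddexp(lse_a, lse_b)
    out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b
    return out, lse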


def zigzag_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    assert causal, "zigzag ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    # second halves of the local tensors, used when only the later query chunk attends
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # repeatedly allocating buffers may be slow, so reuse one workspace per tensor
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            # local block: full causal backward initializes the accumulators
            backward(dout, q, k, v, out, softmax_lse, causal=True)
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
        else:
            if step <= kv_comm.rank:
                # forward attended only to the first half of this k/v block
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                # forward used only the second half of the local queries
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
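
# Note on the backward ring (descriptive, based on the loop above): k/v circulate through
# `kv_comm` exactly as in the forward pass, while the partial dk/dv gradients circulate
# through the separate `d_kv_comm` ring in the same direction. After `world_size` steps
# each rank has accumulated, in `next_dk`/`next_dv`, the full gradient for the k/v shard
# it originally owned, which is why those buffers (not the local `dk`/`dv`) are returned.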


class ZigZagRingFlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = zigzag_ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = zigzag_ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        # one gradient per forward input; the non-tensor arguments get None
        return dq, dk, dv, None, None, None, None, None, None, None, None


def zigzag_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    # qkv holds this rank's zigzag shard, packed along dim 2 as (q, k, v)
    return ZigZagRingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def zigzag_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    # kv holds this rank's zigzag shard, packed along dim 2 as (k, v)
    return ZigZagRingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def zigzag_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    # q, k, v are this rank's zigzag shards of the full sequence
    return ZigZagRingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
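

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original module).
    # Assumptions: launched with `torchrun --nproc_per_node=<N>`, one CUDA device per
    # rank, NCCL available, and flash-attn installed. Every rank builds the same full
    # q/k/v from a shared seed and then keeps only its zigzag shard, i.e. sequence
    # chunks `rank` and `2 * world_size - 1 - rank` concatenated.
    import os

    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.manual_seed(0)

    batch, seqlen, nheads, headdim = 1, 4096, 8, 64

    def extract_local(x, dim=1):
        # zigzag shard: this rank's chunk plus its mirror chunk from the other end
        chunks = x.chunk(2 * world_size, dim=dim)
        shard = torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)
        return shard.contiguous()

    q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    local_q = extract_local(q).requires_grad_()
    local_k = extract_local(k)
    local_v = extract_local(v)

    out = zigzag_ring_flash_attn_func(local_q, local_k, local_v, causal=True)
    out.sum().backward()
    if rank == 0:
        print("zigzag ring flash attention ran; local out shape:", tuple(out.shape))
    dist.destroy_process_group()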