Spaces:
Running
on
A100
Running
on
A100
File size: 12,035 Bytes
174ae06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module
from llava.train.sequence_parallel.globals import (
get_ulysses_seq_len,
get_ulysses_sp_pg,
get_ulysses_sp_rank,
get_ulysses_sp_size,
set_ulysses_seq_len,
)
def all_to_all_4D(input: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    All-to-all resharding of a 4D attention tensor (Q, K, V or the attention output).

    ``P`` below is ``get_ulysses_sp_size()``. Two directions are supported:

    * ``scatter_idx=2, gather_idx=1`` (the defaults): scatter heads, gather sequence.
      ``(bs, local_seqlen, hc, hs) -> (bs, sum of all local_seqlen, hc/P, hs)``.
      The per-rank sequence lengths are all-gathered over ``get_ulysses_sp_pg()`` and
      stored with ``set_ulysses_seq_len`` so the reverse call can undo the padding.
    * ``scatter_idx=1, gather_idx=2``: scatter sequence, gather heads.
      ``(bs, sum of all local_seqlen, hc/P, hs) -> (bs, local_seqlen, hc, hs)``.
      Uses the lengths returned by ``get_ulysses_seq_len()``.

    Because ranks may hold different sequence lengths (multi-modal inputs), every
    shard is zero-padded to the longest one before the exchange and unpadded after.

    Args:
        input (torch.Tensor): 4D tensor laid out as described above.
        scatter_idx (int): dimension to scatter, 2 (default) or 1.
        gather_idx (int): dimension to gather, 1 (default) or 2.
        group: process group passed to ``dist.barrier`` / ``dist.all_to_all_single``.

    Returns:
        torch.Tensor: the resharded, contiguous tensor (shape depends on direction).

    Raises:
        AssertionError: if ``input`` is not 4D, or the reverse direction is requested
            while ``get_ulysses_seq_len()`` returns ``None``.
        RuntimeError: for an unsupported ``(scatter_idx, gather_idx)`` pair, or if the
            head count is not divisible by ``P``.
    """
    assert input.dim() == 4, f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    gather_sequence = scatter_idx == 2 and gather_idx == 1
    scatter_sequence = scatter_idx == 1 and gather_idx == 2
    if not (gather_sequence or scatter_sequence):
        # Reject bad arguments before touching any process-group state.
        raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")

    # (DL): Use the ulysses size instead of dist.get_world_size(group) to handle hybrid parallelism.
    seq_world_size = get_ulysses_sp_size()

    if gather_sequence:
        # input: (bs, local_seqlen, hc, hs)  ->  output: (bs, total_seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        if hc % seq_world_size != 0:
            # Without this check the floor division below surfaces later as an opaque reshape error.
            raise RuntimeError(
                f"number of heads ({hc}) must be divisible by the sequence parallel size ({seq_world_size})"
            )
        shard_hc = hc // seq_world_size

        # (Dacheng): For multi-modality use case, sequence length differs across ranks, which an
        # equal-split a2a cannot handle. Gather the lengths and pad first.
        # (Dacheng): This runs on every attention call so that the second a2a sees matching lengths.
        # (TODO) Maybe can optimize to once per forward call.
        sp_group = get_ulysses_sp_pg()
        ulysses_seq_len = [
            torch.zeros(1, dtype=torch.int64, device=input.device) for _ in range(seq_world_size)
        ]
        # Send buffer has the same shape/dtype as each receive buffer.
        local_len = torch.tensor([shard_seqlen], dtype=torch.int64, device=input.device)
        dist.barrier(group=sp_group)
        dist.all_gather(ulysses_seq_len, local_len, group=sp_group)
        # The gathered tensors themselves are stored, exactly as before.
        set_ulysses_seq_len(ulysses_seq_len)

        # Convert to Python ints once; all shape arithmetic below is host-side.
        seq_lens = torch.cat(ulysses_seq_len).tolist()
        max_global_length = max(seq_lens)

        # Pad the sequence dimension (dim 1) up to the longest shard.
        input = torch.nn.functional.pad(input, (0, 0, 0, 0, 0, max_global_length - shard_seqlen))

        # Transpose groups of heads with the seq-len parallel dimension, so that we can scatter them:
        # (bs, L, hc, hs) -reshape-> (bs, L, P, hc/P, hs) -transpose(0, 2)-> (P, L, bs, hc/P, hs)
        input_t = (
            input.reshape(bs, max_global_length, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()
        )
        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        dist.barrier(group=group)
        dist.all_to_all_single(output, input_t, group=group)

        # (P, L, bs, hc/P, hs) -> (P * L, bs, hc/P, hs); chunk i came from rank i.
        output = output.reshape(max_global_length * seq_world_size, bs, shard_hc, hs)
        output_list = torch.split(output, max_global_length, dim=0)
        assert len(output_list) == seq_world_size
        # Drop each rank's padding, then stitch the true sequence back together.
        output = torch.cat([chunk[:length] for chunk, length in zip(output_list, seq_lens)])
        # (total_seqlen, bs, hc/P, hs) -> (bs, total_seqlen, hc/P, hs)
        return output.transpose(0, 1).contiguous().reshape(bs, sum(seq_lens), shard_hc, hs)

    # scatter_idx == 1 and gather_idx == 2
    ulysses_seq_len = get_ulysses_seq_len()
    assert ulysses_seq_len is not None, "the second a2a (scatter 1, gather 2) is called at first."
    # Entries may be 1-element tensors; normalise to ints once.
    seq_lens = [int(length) for length in ulysses_seq_len]

    # input: (bs, total_seqlen, hc/P, hs)  ->  output: (bs, local_seqlen, hc, hs)
    bs, _, shard_hc, hs = input.shape
    hc = shard_hc * seq_world_size

    # Re-create the padded layout: one equally sized slab per rank.
    max_global_length = max(seq_lens)
    padded_input_list = [
        torch.nn.functional.pad(piece, (0, 0, 0, 0, 0, max_global_length - piece.shape[1]))
        for piece in torch.split(input, seq_lens, dim=1)
    ]
    input = torch.cat(padded_input_list, dim=1)

    # (bs, P * L, hc/P, hs) -reshape-> (bs, P, L, hc/P, hs) -transpose(0, 3)-> (hc/P, P, L, bs, hs)
    # -transpose(0, 1)-> (P, hc/P, L, bs, hs)
    input_t = (
        input.reshape(bs, seq_world_size, max_global_length, shard_hc, hs)
        .transpose(0, 3)
        .transpose(0, 1)
        .contiguous()
    )
    output = torch.empty_like(input_t)
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
    dist.barrier(group=group)
    dist.all_to_all_single(output, input_t, group=group)

    # (P, hc/P, L, bs, hs) -> (hc, L, bs, hs); head group i came from rank i.
    output = output.reshape(hc, max_global_length, bs, hs)
    # Unpad down to this rank's own sequence length.
    self_length = seq_lens[get_ulysses_sp_rank()]
    output = output[:, :self_length]
    # (hc, local_seqlen, bs, hs) -transpose(0, 2)-> (bs, local_seqlen, hc, hs)
    return output.transpose(0, 2).contiguous().reshape(bs, self_length, hc, hs)
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd function that runs ``all_to_all_4D`` in forward and its mirrored call in backward."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        """Reshard ``input`` with ``all_to_all_4D`` and remember the routing on ``ctx``."""
        # Stash everything backward needs to issue the opposite exchange.
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        resharded = all_to_all_4D(input, scatter_idx, gather_idx, group=group)
        return resharded

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        """Route the incoming gradient back with scatter and gather indices swapped.

        Only the ``input`` slot gets a gradient; ``group`` and the two indices get ``None``.
        """
        grad_input = SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx)
        return None, grad_input, None, None
def all_to_all_5D(input: torch.Tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    All-to-all resharding of a packed QKV tensor.

    ``P`` below is ``dist.get_world_size(group)``. Two directions are supported:

    * ``scatter_idx=3, gather_idx=1`` (the defaults): scatter heads, gather sequence.
      ``(bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs)``
    * ``scatter_idx=1, gather_idx=3``: scatter sequence, gather heads.
      ``(bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs)``

    Args:
        input (torch.Tensor): 5D tensor whose dim 2 has size 3.
        scatter_idx (int): dimension to scatter, 3 (default) or 1.
        gather_idx (int): dimension to gather, 1 (default) or 3.
        group: process group used for the collective calls.

    Returns:
        torch.Tensor: the resharded, contiguous tensor (shape depends on direction).

    Raises:
        AssertionError: if ``input`` is not 5D or its dim 2 is not of size 3.
        RuntimeError: for an unsupported ``(scatter_idx, gather_idx)`` pair, or if the
            dimension being scattered is not divisible by ``P``.
    """
    assert input.dim() == 5, f"input must be 5D tensor, got {input.dim()} and shape {input.shape}"

    gather_sequence = scatter_idx == 3 and gather_idx == 1
    scatter_sequence = scatter_idx == 1 and gather_idx == 3
    if not (gather_sequence or scatter_sequence):
        # Reject bad arguments before querying the process group.
        raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3")

    # Both directions hard-code 3 in their reshapes, so both must validate it; otherwise a
    # mismatched input can be silently reshaped into the wrong layout.
    t_cnt = input.shape[2]
    assert t_cnt == 3, f"dim 2 of input must have size 3, got {t_cnt}"

    seq_world_size = dist.get_world_size(group)

    if gather_sequence:
        # input: (bs, seqlen/P, 3, hc, hs)  ->  output: (bs, seqlen, 3, hc/P, hs)
        bs, shard_seqlen, _, hc, hs = input.shape
        if hc % seq_world_size != 0:
            raise RuntimeError(
                f"number of heads ({hc}) must be divisible by the sequence parallel size ({seq_world_size})"
            )
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # Transpose groups of heads with the seq-len parallel dimension, so that we can scatter them:
        # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seqlen/P, 3, P, hc/P, hs)
        # -transpose(0, 3)-> (P, seqlen/P, 3, bs, hc/P, hs)
        input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous()
        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        dist.barrier(group=group)
        dist.all_to_all_single(output, input_t, group=group)

        # (P, seqlen/P, 3, bs, hc/P, hs) -> (seqlen, 3, bs, hc/P, hs); chunk i came from rank i.
        output = output.reshape(seqlen, 3, bs, shard_hc, hs)
        # (seqlen, 3, bs, hc/P, hs) -> (bs, seqlen, 3, hc/P, hs)
        return output.transpose(0, 2).transpose(1, 2).contiguous()

    # scatter_idx == 1 and gather_idx == 3
    # input: (bs, seqlen, 3, hc/P, hs)  ->  output: (bs, seqlen/P, 3, hc, hs)
    bs, seqlen, _, shard_hc, hs = input.shape
    if seqlen % seq_world_size != 0:
        raise RuntimeError(
            f"sequence length ({seqlen}) must be divisible by the sequence parallel size ({seq_world_size})"
        )
    hc = shard_hc * seq_world_size
    shard_seqlen = seqlen // seq_world_size

    # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seqlen/P, 3, hc/P, hs)
    # -transpose(0, 4)-> (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1)-> (P, hc/P, seqlen/P, 3, bs, hs)
    input_t = (
        input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs)
        .transpose(0, 4)
        .transpose(0, 1)
        .contiguous()
    )
    output = torch.empty_like(input_t)
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
    dist.barrier(group=group)
    dist.all_to_all_single(output, input_t, group=group)

    # (P, hc/P, seqlen/P, 3, bs, hs) -> (hc, seqlen/P, 3, bs, hs); head group i came from rank i.
    output = output.reshape(hc, shard_seqlen, 3, bs, hs)
    # (hc, seqlen/P, 3, bs, hs) -transpose(0, 3)-> (bs, seqlen/P, 3, hc, hs)
    return output.transpose(0, 3).contiguous()
class SeqAllToAll5D(torch.autograd.Function):
    """Autograd function that runs ``all_to_all_5D`` in forward and its mirrored call in backward."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int = 3,
        gather_idx: int = 1,
    ) -> Tensor:
        """Reshard ``input`` with ``all_to_all_5D`` and remember the routing on ``ctx``."""
        # Stash everything backward needs to issue the opposite exchange.
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        resharded = all_to_all_5D(input, scatter_idx, gather_idx, group=group)
        return resharded

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        """Route the incoming gradient back with scatter and gather indices swapped.

        Only the ``input`` slot gets a gradient; ``group`` and the two indices get ``None``.
        """
        grad_input = SeqAllToAll5D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx)
        return None, grad_input, None, None
class SeqAllGather(torch.autograd.Function):
    """All-gather a tensor across ``group`` and stack the per-rank results along a new dim 0."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Any) -> Tensor:
        """Run ``dist.all_gather`` and return the stacked result.

        ``input`` is indexed as a pair: ``input[0]`` is the list of receive buffers that
        ``dist.all_gather`` fills in place, ``input[1]`` is the tensor this rank contributes.
        """
        receive_buffers = input[0]
        # The receive-buffer list itself is what gets handed to save_for_backward.
        ctx.save_for_backward(receive_buffers)
        local_tensor = input[1]
        dist.all_gather(receive_buffers, local_tensor, group=group)
        return torch.stack(receive_buffers, dim=0)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        """Return ``None`` for ``group`` and ``(None, saved)`` for ``input``.

        NOTE(review): ``grad_output`` is not used; the value handed back is whatever forward
        saved, not something derived from the incoming gradient. Confirm whether gradients are
        actually expected to flow through this function.
        """
        (saved,) = ctx.saved_tensors
        input_grads = (None, saved)
        return None, input_grads
|