# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
import os

import deepspeed.comm as dist
import torch
class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance.__initialized = False
        return cls._instance

    def __init__(self):
        if not self.__initialized:
            self.__initialized = True
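# Note (illustrative, not part of the original file): because ProcessGroupManager below
# inherits from Singleton, repeated construction returns the same object, e.g.
#
#   pgm_a = ProcessGroupManager(2, 2, 2, True, None)   # hypothetical degrees for an 8-GPU run
#   pgm_b = ProcessGroupManager(2, 2, 2, True, None)
#   assert pgm_a is pgm_b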
class ProcessGroupManager(Singleton):
    """
    sp_degree = sp_ring_degree x sp_ulysses_degree
    """

    def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
        if not hasattr(self, "__initialized"):
            super().__init__()
            self.ulysses_degree = ulysses_degree
            self.ring_type = ring_type
            self.ulysses_seq_len = None
            self.ring_degree = ring_degree
            self.sp_degree = ring_degree * ulysses_degree
            self.dp_degree = dp_degree

            self.rank = dist.get_rank()

            if self.ring_degree == 1:
                # Using Ulysses Sequence Parallelism only
                num_ulysses_pgs = self.dp_degree
                self.ring_pg = None
                self.ring_rank = None

                for i in range(num_ulysses_pgs):
                    ulysses_ranks = list(range(i * self.ulysses_degree, (i + 1) * self.ulysses_degree))
                    group = dist.new_group(ulysses_ranks)
                    if self.rank in ulysses_ranks:
                        self.ulysses_pg = group

                for sp_rank in range(self.sp_degree):
                    dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
                    group = dist.new_group(dp_ranks)
                    if self.rank in dp_ranks:
                        self.dp_pg = group

                self.ulysses_rank = dist.get_rank(self.ulysses_pg)
                self.sp_rank = self.ulysses_rank
                self.dp_rank = dist.get_rank(self.dp_pg)
                self.sp_pg = self.ulysses_pg

                print(f"GPU {torch.cuda.current_device()} Ulysses rank: {self.ulysses_rank} out of {self.sp_degree}")
            else:
                # Using Hybrid Sequence Parallelism
                assert self.ring_degree > 1
                num_ulysses_pgs = self.ring_degree  # world_size // self.ulysses_degree
                num_ring_pgs = self.ulysses_degree  # world_size // self.ring_degree

                # Set up process groups
                if use_ulysses_low:
                    # Adjacent ranks form the Ulysses groups; strided ranks form the Ring groups.
                    for dp_rank in range(dp_degree):
                        offset = dp_rank * self.sp_degree
                        for i in range(num_ulysses_pgs):
                            ulysses_ranks = list(
                                range(
                                    i * self.ulysses_degree + offset,
                                    (i + 1) * self.ulysses_degree + offset,
                                )
                            )
                            group = dist.new_group(ulysses_ranks)
                            if self.rank in ulysses_ranks:
                                self.ulysses_pg = group

                        for i in range(num_ring_pgs):
                            ring_ranks = list(range(i + offset, self.sp_degree + offset, num_ring_pgs))
                            group = dist.new_group(ring_ranks)
                            if self.rank in ring_ranks:
                                self.ring_pg = group
                else:
                    # Adjacent ranks form the Ring groups; strided ranks form the Ulysses groups.
                    for dp_rank in range(dp_degree):
                        offset = dp_rank * self.sp_degree
                        for i in range(num_ring_pgs):
                            ring_ranks = list(range(i * self.ring_degree + offset, (i + 1) * self.ring_degree + offset))
                            group = dist.new_group(ring_ranks)
                            if self.rank in ring_ranks:
                                self.ring_pg = group

                        for i in range(num_ulysses_pgs):
                            ulysses_ranks = list(range(i + offset, self.sp_degree + offset, num_ulysses_pgs))
                            group = dist.new_group(ulysses_ranks)
                            if self.rank in ulysses_ranks:
                                self.ulysses_pg = group
                for sp_rank in range(self.sp_degree):
                    dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
                    group = dist.new_group(dp_ranks)
                    if self.rank in dp_ranks:
                        self.dp_pg = group

                for i in range(self.dp_degree):
                    sp_ranks = list(range(i * self.sp_degree, (i + 1) * self.sp_degree))
                    group = dist.new_group(sp_ranks)
                    if self.rank in sp_ranks:
                        self.sp_pg = group

                self.ulysses_rank = dist.get_rank(self.ulysses_pg)
                self.ring_rank = dist.get_rank(self.ring_pg)
                self.dp_rank = dist.get_rank(self.dp_pg)

                if use_ulysses_low:
                    self.sp_rank = self.ulysses_rank + self.ring_rank * self.ulysses_degree
                else:
                    self.sp_rank = self.ring_rank + self.ulysses_rank * self.ring_degree

                print(
                    f"Rank {self.rank}, GPU {torch.cuda.current_device()} Hybrid SP rank: {self.sp_rank} out of {self.sp_degree} "
                    f"(Ulysses: {self.ulysses_rank}/{self.ulysses_degree}, Ring: {self.ring_rank}/{self.ring_degree})"
                )

            print("--------------ProcessGroupManager Initialized---------------------")
PROCESS_GROUP_MANAGER = None


def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
    """
    Set the process group manager for sequence parallelism.

    sp_degree = sp_ring_degree x sp_ulysses_degree
    """
    # First check whether torch distributed is already initialized, and set the device accordingly.
    # (DL) TODO: Whether this can be skipped in DeepSpeed.
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(
                "torch distributed is already initialized, skipping initialization ...",
                flush=True,
            )
    else:
        if int(os.environ["RANK"]) == 0:
            print("Initializing Torch distributed.")
        dist.init_distributed(dist_backend="nccl", dist_init_required=True)
        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(dist.get_rank() % local_world_size)

    world_size = dist.get_world_size()
    assert sp_degree <= world_size
    assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_degree} != 0"

    if sp_ring_degree < 1:
        sp_ring_degree = 1
    sp_ulysses_degree = sp_degree // sp_ring_degree
    assert sp_degree % sp_ring_degree == 0, f"sp_degree {sp_degree} % sp_ring_degree {sp_ring_degree} != 0"

    dp_degree = world_size // sp_degree

    # Init the process group manager
    global PROCESS_GROUP_MANAGER
    PROCESS_GROUP_MANAGER = ProcessGroupManager(
        sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
    )
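# A minimal usage sketch (hypothetical values, not taken from this file): under a
# launcher such as `torchrun --nproc_per_node=8 train.py`, a training script might call
#
#   set_pg_manager(sp_degree=4, sp_ring_degree=2, use_ulysses_low=True)
#
# which on 8 ranks gives sp_ulysses_degree = 4 // 2 = 2 and dp_degree = 8 // 4 = 2,
# i.e. two hybrid sequence-parallel groups of 4 ranks, each split into Ulysses pairs
# and Ring pairs by ProcessGroupManager above.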
def get_pg_manager():
    return PROCESS_GROUP_MANAGER


def get_sequence_parallel_size():
    """Get the size of the sequence parallel group."""
    return PROCESS_GROUP_MANAGER.sp_degree


def get_sequence_parallel_rank():
    """Get the rank of this process in the sequence parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.sp_rank


def get_sequence_parallel_pg():
    """Get the overall sequence parallel process group (including both Ring and Ulysses)."""
    return PROCESS_GROUP_MANAGER.sp_pg


def get_ulysses_sp_size():
    """Get the size of the Ulysses sequence parallel group."""
    return PROCESS_GROUP_MANAGER.ulysses_degree


def get_ulysses_seq_len():
    """Get the sequence length(s) recorded for the Ulysses sequence parallel group."""
    return PROCESS_GROUP_MANAGER.ulysses_seq_len


def set_ulysses_seq_len(seq_len):
    """Record the sequence length(s) for the Ulysses sequence parallel group."""
    PROCESS_GROUP_MANAGER.ulysses_seq_len = seq_len
def get_ulysses_sp_rank():
    """Get the rank of this process in the Ulysses sequence parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.ulysses_rank


def get_ulysses_sp_pg():
    """Get the Ulysses sequence parallel process group."""
    return PROCESS_GROUP_MANAGER.ulysses_pg


def get_ring_sp_size():
    """Get the size of the RingAttn sequence parallel group."""
    return PROCESS_GROUP_MANAGER.ring_degree


def get_ring_sp_rank():
    """Get the rank of this process in the RingAttn sequence parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.ring_rank


def get_ring_sp_pg():
    """Get the RingAttn sequence parallel process group."""
    return PROCESS_GROUP_MANAGER.ring_pg


def get_ring_type():
    """Get the RingAttn implementation type."""
    return PROCESS_GROUP_MANAGER.ring_type


def get_data_parallel_size():
    """Get the size of the data parallel group."""
    return PROCESS_GROUP_MANAGER.dp_degree


def get_data_parallel_rank():
    """Get the rank of this process in the data parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.dp_rank
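# A hypothetical sketch of how downstream code might consume these getters to shard a
# [batch, seq, hidden] tensor across the Ulysses group; the helper name and tensor layout
# are assumptions for illustration, not part of this module's API:
#
#   def shard_sequence_for_ulysses(hidden_states):
#       sp_size = get_ulysses_sp_size()
#       sp_rank = get_ulysses_sp_rank()
#       seq_len = hidden_states.shape[1]
#       assert seq_len % sp_size == 0, "sequence length must be divisible by the Ulysses degree"
#       shard = seq_len // sp_size
#       # Each rank keeps only its contiguous slice of the sequence dimension.
#       return hidden_states[:, sp_rank * shard : (sp_rank + 1) * shard]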