Spaces:
Running
on
A100
Running
on
A100
# Copyright (c) 2025 NVIDIA CORPORATION. | |
# Licensed under the MIT license. | |
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
# LICENSE is in incl_licenses directory. | |
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
import copy | |
import json | |
import os | |
import pathlib | |
import re | |
import warnings | |
from dataclasses import dataclass | |
import torch | |
import torch.distributed as dist | |
from accelerate.hooks import add_hook_to_module | |
from transformers import PretrainedConfig, PreTrainedModel | |
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled | |
from llava.train.sequence_parallel.globals import get_pg_manager, get_ulysses_sp_pg | |
def rprint(*args, **kwargs):
    """Print on every rank, tagging the output with the rank when distributed.

    The rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables (defaulting to 0 and 1). When more than one process
    is configured and ``torch.distributed`` has been initialized, the message
    is prefixed with ``[dist-<rank>-of-<world_size>]``; otherwise it is printed
    as-is. All arguments are forwarded to :func:`print`, whose return value
    (``None``) is returned.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    running_distributed = world_size > 1 and dist.is_initialized()
    # An empty prefix leaves the message untouched in the single-process case.
    prefix = (f"[dist-{rank}-of-{world_size}]",) if running_distributed else ()
    return print(*prefix, *args, **kwargs)
def mprint(*args, **kwargs):
    """Print only from the main process (rank 0) when running distributed.

    The rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables (defaulting to 0 and 1). Outside of an initialized
    multi-process run the message is printed unchanged. Inside one, rank 0
    prints it with a ``[dist-<rank>-of-<world_size>]`` prefix and every other
    rank stays silent. Always returns ``None``.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Single process, or torch.distributed not set up yet: plain print.
    if world_size <= 1 or not dist.is_initialized():
        return print(*args, **kwargs)

    # Non-main ranks are muted.
    if rank != 0:
        return None

    return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
def is_local(model_name_or_path: str) -> bool:
    """Return True when *model_name_or_path* is an existing directory on disk."""
    points_to_directory = os.path.isdir(model_name_or_path)
    return points_to_directory
def get_checkpoint_path(output_dir: str, checkpoint_prefix: str = "checkpoint") -> str | None: | |
output_dir = os.path.abspath(output_dir) | |
pathlib_dir = pathlib.Path(output_dir) | |
if list(pathlib_dir.glob("config.json")): | |
# training has been finished | |
return output_dir, False | |
else: | |
try: | |
ordering_and_checkpoint_path = [] | |
glob_checkpoints = [ | |
str(x) for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x) | |
] | |
for path in glob_checkpoints: | |
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) | |
if regex_match is not None and regex_match.groups() is not None: | |
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) | |
checkpoints_sorted = sorted(ordering_and_checkpoint_path) | |
return checkpoints_sorted[-1][1], True | |
except: | |
return None, True | |
def prepare_config_for_training(
    config: PretrainedConfig, model_args: dataclass, training_args: dataclass, data_args: dataclass
) -> None:
    """Populate *config* in place with everything needed to start or resume training.

    Args:
        config: Model configuration to fill in; mutated in place.
        model_args: Model arguments. Must provide ``chat_template`` and non-None
            ``vision_tower``, ``speech_tower`` and ``sound_tower``. Every
            attribute in its ``__dict__`` is finally copied onto *config*.
        training_args: Training arguments providing ``bf16``, the ``tune_*``
            flags and, optionally, ``deepspeed``.
        data_args: Data arguments. ``image_aspect_ratio`` is read from it, or
            written back to it from *config* when only *config* defines it.

    Raises:
        AssertionError: If any of the three towers is missing from *model_args*.
        ValueError: If ``image_aspect_ratio`` is set in neither *data_args* nor
            *config*.
    """
    config.chat_template = model_args.chat_template
    assert model_args.vision_tower is not None, "requires vision tower"
    assert model_args.speech_tower is not None, "requires speech tower"
    assert model_args.sound_tower is not None, "requires sound tower"

    # set module configurations
    # Each entry maps a config field to the model_args attribute used as its
    # default. The attribute is only read when the config field is unset, so a
    # config that already defines the field is left untouched.
    module_cfg_sources = (
        ("llm_cfg", "model_name_or_path"),
        ("vision_tower_cfg", "vision_tower"),
        ("speech_tower_cfg", "speech_tower"),
        ("sound_tower_cfg", "sound_tower"),
        ("mm_projector_cfg", "mm_projector"),
        ("speech_mm_projector_cfg", "speech_mm_projector"),
        ("sound_mm_projector_cfg", "sound_mm_projector"),
    )
    for cfg_field, arg_name in module_cfg_sources:
        if getattr(config, cfg_field, None) is None:
            setattr(config, cfg_field, getattr(model_args, arg_name))

    # set default dtype (stored as its string form, e.g. "torch.bfloat16")
    config.model_dtype = str(torch.bfloat16 if training_args.bf16 else torch.float16)

    # set tuning modules
    tuning_flags = (
        "tune_language_model",
        "tune_vision_tower",
        "tune_speech_tower",
        "tune_sound_tower",
        "tune_mm_projector",
        "tune_speech_mm_projector",
        "tune_sound_mm_projector",
    )
    for flag in tuning_flags:
        setattr(config, flag, getattr(training_args, flag))

    # set data args
    # Get the image_aspect_ratio from the config if it is defined there
    # (case of resuming from a checkpoint) or from the data_args
    # (i.e. from the command line when starting a new training).
    data_aspect_ratio = getattr(data_args, "image_aspect_ratio", None)
    config_aspect_ratio = getattr(config, "image_aspect_ratio", None)
    if data_aspect_ratio is not None:
        if config_aspect_ratio is None:
            config.image_aspect_ratio = data_aspect_ratio
    elif config_aspect_ratio is not None:
        data_args.image_aspect_ratio = config_aspect_ratio
    else:
        raise ValueError("image_aspect_ratio must be set either in data_args or in the pretrained config")

    # Only a deepspeed setting whose value contains "mics" is recorded on the config.
    deepspeed_setting = getattr(training_args, "deepspeed", None)
    if deepspeed_setting is not None and "mics" in deepspeed_setting:
        config.deepspeed = deepspeed_setting

    # Mirror every model argument onto the config. Values that parse as JSON are
    # stored decoded; anything else is stored unchanged.
    for key, value in vars(model_args).items():
        try:
            value = json.loads(value)
        except (TypeError, ValueError):
            # TypeError: not a str/bytes value; ValueError: not valid JSON.
            pass
        setattr(config, key, value)
def vision_resolution_elevation(model: PreTrainedModel, config: PretrainedConfig):
    """Ask the model's vision tower to resize its position embeddings if needed.

    Nothing happens when the model has no vision tower or when the tower's
    class name contains "radio" (case-insensitive). Otherwise the tower's
    ``_maybe_resize_pos_embeds`` is called with the target ``vision_resolution``
    (default ``-1``) and ``interpolate_mode`` (default ``"linear"``) taken from
    *config*.
    """
    tower = model.get_vision_tower()
    if tower is None:
        return

    tower_kind = tower.__class__.__name__.lower()
    if "radio" in tower_kind:
        return

    resize_kwargs = {
        "model": tower.vision_tower,
        "image_processor": tower.image_processor,
        "resolution": getattr(config, "vision_resolution", -1),
        "interpolate_mode": getattr(config, "interpolate_mode", "linear"),
    }
    tower._maybe_resize_pos_embeds(**resize_kwargs)
def unit_test_rope_scaling(model: PreTrainedModel, config: PretrainedConfig, training_args: dataclass):
    """Placeholder check for RoPE scaling; ignores its arguments and returns False."""
    return False
def calculate_loss_weight(labels, ignore_index=-100):
    """Return this rank's loss weight based on its share of active label tokens.

    (Qinghao): Weighted loss based on num_active_elements. To achieve accurate
    sequence parallel loss calculation, the real number of active elements of
    each sequence partition is needed. For data parallelism the loss remains
    almost the same (and is also more accurate).

    The labels are shifted by one position (the first position along the last
    axis is dropped) and every entry different from *ignore_index* counts as
    active. The result is ``local_active / global_active * world_size``, where
    ``global_active`` is the all-reduced count over the default process group.
    Requires ``torch.distributed`` to be initialized.
    """
    shifted = labels[..., 1:].contiguous().view(-1)

    # Count the positions that actually contribute to the loss on this rank.
    local_active = shifted.ne(ignore_index).long().sum()  # IGNORE_INDEX = -100 by default

    # Reduce into a copy so the local count stays available for the ratio.
    global_active = local_active.detach().clone()
    dist.all_reduce(global_active)

    world_size = dist.get_world_size()
    return local_active / global_active * world_size
def reshard_hiddne_states_and_labels(hidden_states, labels):
    """Re-partition hidden states and labels across the sequence-parallel group.

    Every rank gathers the labels and hidden states of all ranks in the Ulysses
    sequence-parallel group, keeps only the positions whose (shifted) label is
    not ``IGNORE_INDEX``, and takes a contiguous slice of those effective
    positions. Each rank gets ``effective_count // sp_degree`` positions and the
    last rank also takes the remainder.

    Args:
        hidden_states: This rank's hidden states. The indexing below treats it
            as ``(batch, local_seq_len, hidden_dim)`` -- TODO confirm with caller.
        labels: This rank's labels, a 2-D tensor ``(batch, local_seq_len)``.

    Returns:
        ``(effective_local_hidden_states, effective_local_labels)``: the 2-D
        slice of flattened hidden states and the matching 1-D slice of labels
        assigned to this rank.

    NOTE(review): if no label differs from ``IGNORE_INDEX`` on any rank, the
    ``sp_degree // <count>`` expression below divides by zero.
    """
    PROCESS_GROUP_MANAGER = get_pg_manager()
    sp_degree = PROCESS_GROUP_MANAGER.sp_degree
    sp_rank = PROCESS_GROUP_MANAGER.sp_rank
    sp_group = PROCESS_GROUP_MANAGER.ulysses_pg
    from llava.constants import IGNORE_INDEX

    # Get the sequence length held by each sp rank (shards may differ in length).
    bs, shard_seqlen = labels.shape
    ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=labels.device) for _ in range(sp_degree)]
    dist.barrier(group=sp_group)
    dist.all_gather(ulysses_seq_len, torch.tensor(shard_seqlen, device=labels.device), group=sp_group)
    dist.barrier(group=sp_group)
    global_seq_len = torch.cat(ulysses_seq_len, dim=0)

    # Gather all labels and flatten them. One receive buffer per rank, sized
    # with that rank's sequence length.
    all_labels = [
        torch.zeros(bs, seq_len, dtype=labels.dtype, device=labels.device).contiguous() for seq_len in ulysses_seq_len
    ]
    dist.all_gather(all_labels, labels.contiguous(), group=sp_group)
    # Drop the first position of the concatenated sequence (label shift) before flattening.
    flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].contiguous().view(-1)

    # Get the indices of the labels that are != IGNORE_INDEX.
    flatten_label_mask = flatten_global_labels.ne(IGNORE_INDEX)
    flatten_effective_label_index = flatten_label_mask.nonzero(as_tuple=True)

    # Pad (by repetition) when there are fewer effective labels than sp ranks,
    # so that every rank receives at least one position.
    if flatten_effective_label_index[0].shape[0] < sp_degree:
        warnings.warn(
            f"The effective label length {flatten_effective_label_index[0].shape[0]} is smaller than sp_degree {sp_degree}, padding the index"
        )
        repeat_num = sp_degree // flatten_effective_label_index[0].shape[0] + 1
    else:
        repeat_num = 1

    # Reconstruct the labels by selecting from the global labels.
    effective_global_labels = flatten_global_labels[flatten_effective_label_index]
    if repeat_num > 1:
        effective_global_labels = effective_global_labels.repeat(repeat_num)

    # Global effective sequence length and the per-rank share of it.
    global_effective_seq_len = effective_global_labels.shape[0]
    reshard_size = global_effective_seq_len // sp_degree

    # Slice boundaries for this rank:
    #   original_start_id/original_end_id -- where this rank's shard sits in the
    #       concatenated (un-filtered) sequence;
    #   start_id/end_id -- this rank's slice of the effective positions.
    # The last rank extends its slice to the end to absorb the remainder.
    if sp_rank == 0:
        original_start_id = 0
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = 0
        end_id = reshard_size * (sp_rank + 1)
    elif sp_rank == sp_degree - 1:
        original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = reshard_size * sp_rank
        end_id = global_effective_seq_len
    else:
        original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = reshard_size * sp_rank
        end_id = reshard_size * (sp_rank + 1)

    # Get the local labels.
    effective_local_labels = torch.narrow(effective_global_labels, 0, start_id, end_id - start_id)

    # Gather all hidden states and flatten them. Each rank writes its shard into
    # its own span of a zero-filled full-length buffer; the all_reduce (default
    # reduce op, i.e. sum -- confirm) then leaves the full sequence on every rank.
    all_hidden_states = torch.zeros(
        bs, torch.sum(global_seq_len), hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device
    ).contiguous()
    all_hidden_states[:, original_start_id:original_end_id, :] += hidden_states
    dist.barrier(group=sp_group)
    dist.all_reduce(all_hidden_states, group=sp_group)
    dist.barrier(group=sp_group)
    # Drop the last position so that flattened hidden-state rows line up with the
    # shifted labels (which dropped the first position).
    flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])

    # Get the local hidden states using the same effective indices and slice as the labels.
    effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
    if repeat_num > 1:
        effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
    effective_local_hidden_states = torch.narrow(effective_flatten_global_hidden_states, 0, start_id, end_id - start_id)

    return effective_local_hidden_states, effective_local_labels
def sp_loss_rescale(shift_labels, loss):
    """Combine per-rank losses into one loss over the sequence-parallel group.

    Each rank's loss is weighted by its share of active labels (entries of
    *shift_labels* different from ``IGNORE_INDEX``) within the Ulysses
    sequence-parallel group, and the weighted losses are then all-reduced over
    that group.

    Args:
        shift_labels: This rank's labels, already shifted by the caller.
        loss: This rank's loss tensor.

    Returns:
        The rescaled loss after the all-reduce over the sequence-parallel group.
    """
    from llava.constants import IGNORE_INDEX

    # Look the group up once and reuse it for both collectives.
    sp_group = get_ulysses_sp_pg()

    num_active_elements = shift_labels.ne(IGNORE_INDEX).sum()  # IGNORE_INDEX = -100 by default

    # Reduce into a copy so the local count survives for the ratio below. This
    # uses detach().clone() rather than copy.deepcopy, matching calculate_loss_weight.
    global_active_sum = num_active_elements.detach().clone()
    dist.all_reduce(global_active_sum, group=sp_group)

    loss = loss * num_active_elements / global_active_sum
    dist.all_reduce(loss, group=sp_group)
    return loss