# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import copy
import json
import os
import pathlib
import re
import warnings
from dataclasses import dataclass
import torch
import torch.distributed as dist
from accelerate.hooks import add_hook_to_module
from transformers import PretrainedConfig, PreTrainedModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from llava.train.sequence_parallel.globals import get_pg_manager, get_ulysses_sp_pg
def rprint(*args, **kwargs):
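    """Rank-aware print: prefixes output with ``[dist-<rank>-of-<world_size>]`` when torch.distributed is initialized with more than one process."""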
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1 and dist.is_initialized():
return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
else:
return print(*args, **kwargs)
def mprint(*args, **kwargs):
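    """Main-process print: in a multi-process distributed run, only rank 0 prints (with a rank prefix); otherwise behaves like plain ``print``."""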
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1 and dist.is_initialized():
if rank == 0:
return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
else:
return
else:
return print(*args, **kwargs)
def is_local(model_name_or_path: str) -> bool:
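    """Return True if ``model_name_or_path`` is an existing local directory (as opposed to a remote model identifier)."""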
return os.path.isdir(model_name_or_path)
def get_checkpoint_path(output_dir: str, checkpoint_prefix: str = "checkpoint") -> tuple[str | None, bool]:
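    """Locate a checkpoint to resume from.

    Returns ``(path, is_resume)``: the finished output directory with ``is_resume=False`` if a
    top-level ``config.json`` exists, otherwise the latest ``checkpoint-*`` sub-directory with
    ``is_resume=True`` (or ``(None, True)`` if no checkpoint can be found).
    """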
output_dir = os.path.abspath(output_dir)
pathlib_dir = pathlib.Path(output_dir)
if list(pathlib_dir.glob("config.json")):
        # config.json exists, so training has already finished
return output_dir, False
else:
try:
ordering_and_checkpoint_path = []
glob_checkpoints = [
str(x) for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)
]
for path in glob_checkpoints:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
return checkpoints_sorted[-1][1], True
        except Exception:
return None, True
def prepare_config_for_training(
config: PretrainedConfig, model_args: dataclass, training_args: dataclass, data_args: dataclass
) -> None:
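    """Populate ``config`` with module paths, tuning flags, dtype, and data settings before training."""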
config.chat_template = model_args.chat_template
assert model_args.vision_tower is not None, "requires vision tower"
assert model_args.speech_tower is not None, "requires speech tower"
assert model_args.sound_tower is not None, "requires sound tower"
# set module configurations
if getattr(config, "llm_cfg", None) is None:
config.llm_cfg = model_args.model_name_or_path
if getattr(config, "vision_tower_cfg", None) is None:
config.vision_tower_cfg = model_args.vision_tower
if getattr(config, "speech_tower_cfg", None) is None:
config.speech_tower_cfg = model_args.speech_tower
if getattr(config, "sound_tower_cfg", None) is None:
config.sound_tower_cfg = model_args.sound_tower
if getattr(config, "mm_projector_cfg", None) is None:
config.mm_projector_cfg = model_args.mm_projector
if getattr(config, "speech_mm_projector_cfg", None) is None:
config.speech_mm_projector_cfg = model_args.speech_mm_projector
if getattr(config, "sound_mm_projector_cfg", None) is None:
config.sound_mm_projector_cfg = model_args.sound_mm_projector
# set default dtype
config.model_dtype = torch.bfloat16 if training_args.bf16 else torch.float16
    config.model_dtype = str(config.model_dtype)
# set tuning modules
config.tune_language_model = training_args.tune_language_model
config.tune_vision_tower = training_args.tune_vision_tower
config.tune_speech_tower = training_args.tune_speech_tower
config.tune_sound_tower = training_args.tune_sound_tower
config.tune_mm_projector = training_args.tune_mm_projector
config.tune_speech_mm_projector = training_args.tune_speech_mm_projector
config.tune_sound_mm_projector = training_args.tune_sound_mm_projector
# set data args
    # Get the image_aspect_ratio from the config if it is defined there
    # (the case when resuming from a checkpoint) or from the data_args
    # (i.e. from the command line when starting a new training run).
if getattr(data_args, "image_aspect_ratio", None) is not None:
if getattr(config, "image_aspect_ratio", None) is None:
config.image_aspect_ratio = data_args.image_aspect_ratio
elif getattr(config, "image_aspect_ratio", None) is not None:
data_args.image_aspect_ratio = config.image_aspect_ratio
else:
raise ValueError("image_aspect_ratio must be set either in data_args or in the pretrained config")
if (
hasattr(training_args, "deepspeed")
and training_args.deepspeed is not None
and "mics" in training_args.deepspeed
):
config.deepspeed = training_args.deepspeed
    # Copy the remaining model args onto the config, decoding JSON-encoded string values when possible
    for key, value in model_args.__dict__.items():
        try:
            value = json.loads(value)
        except (TypeError, ValueError):
pass
setattr(config, key, value)
def vision_resolution_elevation(model: PreTrainedModel, config: PretrainedConfig):
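    """Resize the vision tower's positional embeddings to ``config.vision_resolution`` (skipped for RADIO towers)."""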
vision_tower = model.get_vision_tower()
if vision_tower is not None and "radio" not in vision_tower.__class__.__name__.lower():
vision_tower._maybe_resize_pos_embeds(
model=vision_tower.vision_tower,
image_processor=vision_tower.image_processor,
resolution=getattr(config, "vision_resolution", -1),
interpolate_mode=getattr(config, "interpolate_mode", "linear"),
)
def unit_test_rope_scaling(model: PreTrainedModel, config: PretrainedConfig, training_args: dataclass):
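    """Placeholder hook for RoPE-scaling unit tests; currently always returns False."""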
return False
def calculate_loss_weight(labels, ignore_index=-100):
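    """Compute a per-rank loss weight proportional to this rank's share of non-ignored (active) tokens."""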
    # (Qinghao): Weighted loss based on num_active_elements.
    # To compute the sequence-parallel loss accurately, we need the real number of
    # active elements in each sequence partition.
    # For data parallelism, the loss stays almost the same (and becomes more accurate).
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
padding_mask = shift_labels.eq(ignore_index) # IGNORE_INDEX = -100 by default
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
# global_active_sum = copy.deepcopy(num_active_elements)
global_active_sum = num_active_elements.detach().clone()
dist.all_reduce(global_active_sum)
loss_weight = num_active_elements / global_active_sum * dist.get_world_size()
return loss_weight
def reshard_hiddne_states_and_labels(hidden_states, labels):
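    """Gather hidden states and labels across the sequence-parallel group, keep only positions whose
    (shifted) label is not IGNORE_INDEX, and return an evenly resharded local slice for loss computation."""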
PROCESS_GROUP_MANAGER = get_pg_manager()
sp_degree = PROCESS_GROUP_MANAGER.sp_degree
sp_rank = PROCESS_GROUP_MANAGER.sp_rank
sp_group = PROCESS_GROUP_MANAGER.ulysses_pg
from llava.constants import IGNORE_INDEX
# Get the seq len on different sp ranks
bs, shard_seqlen = labels.shape
ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=labels.device) for _ in range(sp_degree)]
dist.barrier(group=sp_group)
dist.all_gather(ulysses_seq_len, torch.tensor(shard_seqlen, device=labels.device), group=sp_group)
dist.barrier(group=sp_group)
global_seq_len = torch.cat(ulysses_seq_len, dim=0)
    # Gather all labels and flatten them
all_labels = [
torch.zeros(bs, seq_len, dtype=labels.dtype, device=labels.device).contiguous() for seq_len in ulysses_seq_len
]
dist.all_gather(all_labels, labels.contiguous(), group=sp_group)
# flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].view(-1)
flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].contiguous().view(-1)
    # Get the indices where label != IGNORE_INDEX
flatten_label_mask = flatten_global_labels.ne(IGNORE_INDEX)
flatten_effective_label_index = flatten_label_mask.nonzero(as_tuple=True)
    # Pad effective_label_index if its length is smaller than sp_degree
if flatten_effective_label_index[0].shape[0] < sp_degree:
warnings.warn(
f"The effective label length {flatten_effective_label_index[0].shape[0]} is smaller than sp_degree {sp_degree}, padding the index"
)
repeat_num = sp_degree // flatten_effective_label_index[0].shape[0] + 1
else:
repeat_num = 1
# Reconstruct the labels by selecting from the global labels
effective_global_labels = flatten_global_labels[flatten_effective_label_index]
if repeat_num > 1:
effective_global_labels = effective_global_labels.repeat(repeat_num)
    # Global effective sequence length
global_effective_seq_len = effective_global_labels.shape[0]
reshard_size = global_effective_seq_len // sp_degree
    # Shard boundaries used to reshard the hidden states and labels
if sp_rank == 0:
original_start_id = 0
original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
start_id = 0
end_id = reshard_size * (sp_rank + 1)
elif sp_rank == sp_degree - 1:
original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
start_id = reshard_size * sp_rank
end_id = global_effective_seq_len
else:
original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
start_id = reshard_size * sp_rank
end_id = reshard_size * (sp_rank + 1)
# Get the local labels
effective_local_labels = torch.narrow(effective_global_labels, 0, start_id, end_id - start_id)
    # Gather all hidden states and flatten them
# all_hidden_states = [torch.zeros(bs, seq_len, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=True).contiguous() for seq_len in ulysses_seq_len]
all_hidden_states = torch.zeros(
bs, torch.sum(global_seq_len), hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device
).contiguous()
all_hidden_states[:, original_start_id:original_end_id, :] += hidden_states
dist.barrier(group=sp_group)
dist.all_reduce(all_hidden_states, group=sp_group)
dist.barrier(group=sp_group)
flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])
# Get the local hidden states
effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
if repeat_num > 1:
effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
effective_local_hidden_states = torch.narrow(effective_flatten_global_hidden_states, 0, start_id, end_id - start_id)
return effective_local_hidden_states, effective_local_labels
def sp_loss_rescale(shift_labels, loss):
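    """Rescale and sum the loss across the sequence-parallel group, weighting each rank by its number of active (non-ignored) labels."""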
from llava.constants import IGNORE_INDEX
PROCESS_GROUP_MANAGER = get_pg_manager()
labels_mask = shift_labels.ne(IGNORE_INDEX) # IGNORE_INDEX = -100 by default
num_active_elements = torch.sum(labels_mask)
global_active_sum = copy.deepcopy(num_active_elements)
# dist.barrier(group=get_ulysses_sp_pg())
dist.all_reduce(global_active_sum, group=get_ulysses_sp_pg())
# print(loss.shape, num_active_elements.shape, global_active_sum.shape)
loss = loss * num_active_elements / global_active_sum
dist.all_reduce(loss, group=get_ulysses_sp_pg())
return loss