# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import copy
import json
import os
import pathlib
import re
import warnings
from dataclasses import dataclass

import torch
import torch.distributed as dist
from accelerate.hooks import add_hook_to_module
from transformers import PretrainedConfig, PreTrainedModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from llava.train.sequence_parallel.globals import get_pg_manager, get_ulysses_sp_pg


def rprint(*args, **kwargs):
    """Print from every rank, prefixed with ``[dist-<rank>-of-<world_size>]`` in distributed runs.

    Rank and world size are read from the ``RANK`` / ``WORLD_SIZE`` environment
    variables; the prefix is only added when ``WORLD_SIZE > 1`` and the default
    process group is initialized. Returns whatever ``print`` returns (``None``).
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1 and dist.is_initialized():
        return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
    return print(*args, **kwargs)


def mprint(*args, **kwargs):
    """Print only from rank 0 in distributed runs (with a rank prefix); print normally otherwise.

    Uses the same ``RANK`` / ``WORLD_SIZE`` environment variables and the same
    "distributed" condition as :func:`rprint`. Non-zero ranks print nothing.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if not (world_size > 1 and dist.is_initialized()):
        return print(*args, **kwargs)
    if rank != 0:
        return None
    return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)


def is_local(model_name_or_path: str) -> bool:
    """Return True if ``model_name_or_path`` is an existing local directory."""
    return os.path.isdir(model_name_or_path)


def get_checkpoint_path(output_dir: str, checkpoint_prefix: str = "checkpoint") -> tuple[str | None, bool]:
    """Find what to resume from inside ``output_dir``.

    Args:
        output_dir: Training output directory (made absolute).
        checkpoint_prefix: Prefix of checkpoint subdirectories, i.e. ``<prefix>-<step>``.

    Returns:
        A ``(path, is_checkpoint)`` tuple:

        * ``(output_dir, False)`` if ``output_dir/config.json`` exists (training finished);
        * ``(latest_checkpoint_dir, True)`` for the subdirectory with the largest ``<step>``;
        * ``(None, True)`` if no ``<prefix>-<step>`` subdirectory exists.
    """
    output_dir = os.path.abspath(output_dir)
    pathlib_dir = pathlib.Path(output_dir)
    if (pathlib_dir / "config.json").exists():
        # training has been finished
        return output_dir, False

    # Match on the directory *name* (not the full path) so that a parent directory that
    # happens to look like "<prefix>-<n>" cannot supply the step number; escape the prefix
    # so regex metacharacters in it are matched literally.
    step_pattern = re.compile(rf"{re.escape(checkpoint_prefix)}-([0-9]+)")
    ordering_and_checkpoint_path = []
    for candidate in pathlib_dir.glob(f"{checkpoint_prefix}-*"):
        if not candidate.is_dir():
            continue
        regex_match = step_pattern.match(candidate.name)
        if regex_match is not None:
            ordering_and_checkpoint_path.append((int(regex_match.group(1)), str(candidate)))

    if not ordering_and_checkpoint_path:
        return None, True
    # Tuples compare by step first, so max() picks the latest checkpoint.
    return max(ordering_and_checkpoint_path)[1], True


def prepare_config_for_training(
    config: PretrainedConfig, model_args: dataclass, training_args: dataclass, data_args: dataclass
) -> None:
    """Fill ``config`` in place with module, dtype, tuning and data settings for training.

    Module config entries (``*_cfg``) are only set when ``config`` does not already define
    them, so values from a resumed checkpoint take precedence over ``model_args``.
    Finally every attribute of ``model_args`` is copied onto ``config``; string values that
    parse as JSON are stored decoded, everything else is stored as-is.

    Raises:
        AssertionError: if the vision, speech or sound tower is missing from ``model_args``.
        ValueError: if ``image_aspect_ratio`` is set neither on ``data_args`` nor on ``config``.
    """
    config.chat_template = model_args.chat_template
    assert model_args.vision_tower is not None, "requires vision tower"
    assert model_args.speech_tower is not None, "requires speech tower"
    assert model_args.sound_tower is not None, "requires sound tower"

    # set module configurations (config attribute <- model_args attribute), only if unset
    for config_key, arg_name in (
        ("llm_cfg", "model_name_or_path"),
        ("vision_tower_cfg", "vision_tower"),
        ("speech_tower_cfg", "speech_tower"),
        ("sound_tower_cfg", "sound_tower"),
        ("mm_projector_cfg", "mm_projector"),
        ("speech_mm_projector_cfg", "speech_mm_projector"),
        ("sound_mm_projector_cfg", "sound_mm_projector"),
    ):
        if getattr(config, config_key, None) is None:
            setattr(config, config_key, getattr(model_args, arg_name))

    # set default dtype (stored as its string form, e.g. "torch.bfloat16")
    config.model_dtype = str(torch.bfloat16 if training_args.bf16 else torch.float16)

    # set tuning modules
    for flag in (
        "tune_language_model",
        "tune_vision_tower",
        "tune_speech_tower",
        "tune_sound_tower",
        "tune_mm_projector",
        "tune_speech_mm_projector",
        "tune_sound_mm_projector",
    ):
        setattr(config, flag, getattr(training_args, flag))

    # set data args
    # Get the image_aspect_ratio from the config if it is defined there
    # (case of resuming from a checkpoint) or from the data_args
    # (i.e. from the command line when starting a new training).
    # NOTE(review): when both are set, config and data_args each keep their own value;
    # confirm that a mismatch between them is acceptable.
    if getattr(data_args, "image_aspect_ratio", None) is not None:
        if getattr(config, "image_aspect_ratio", None) is None:
            config.image_aspect_ratio = data_args.image_aspect_ratio
    elif getattr(config, "image_aspect_ratio", None) is not None:
        data_args.image_aspect_ratio = config.image_aspect_ratio
    else:
        raise ValueError("image_aspect_ratio must be set either in data_args or in the pretrained config")

    deepspeed_cfg = getattr(training_args, "deepspeed", None)
    if deepspeed_cfg is not None and "mics" in deepspeed_cfg:
        config.deepspeed = deepspeed_cfg

    for key, value in vars(model_args).items():
        try:
            value = json.loads(value)
        except (TypeError, ValueError):
            # TypeError: value is not a str/bytes (e.g. None, bool, int).
            # ValueError (json.JSONDecodeError): a plain string that is not JSON.
            # In both cases keep the original value.
            pass
        setattr(config, key, value)


def vision_resolution_elevation(model: PreTrainedModel, config: PretrainedConfig):
    """Ask the vision tower to resize its position embeddings for ``config.vision_resolution``.

    Does nothing if the model has no vision tower or its class name contains "radio".
    ``vision_resolution`` defaults to -1 and ``interpolate_mode`` to "linear" when absent
    from ``config``; how these are interpreted is up to the tower's
    ``_maybe_resize_pos_embeds`` (not visible here — confirm against its implementation).
    """
    vision_tower = model.get_vision_tower()
    if vision_tower is not None and "radio" not in vision_tower.__class__.__name__.lower():
        vision_tower._maybe_resize_pos_embeds(
            model=vision_tower.vision_tower,
            image_processor=vision_tower.image_processor,
            resolution=getattr(config, "vision_resolution", -1),
            interpolate_mode=getattr(config, "interpolate_mode", "linear"),
        )


def unit_test_rope_scaling(model: PreTrainedModel, config: PretrainedConfig, training_args: dataclass):
    """Placeholder hook: performs no check and always returns False."""
    return False


def calculate_loss_weight(labels, ignore_index=-100):
    """Return this rank's loss weight based on its number of supervised tokens.

    weight = local_active / global_active * world_size, where "active" counts next-token
    labels (``labels[..., 1:]``) that are not ``ignore_index`` and the global count is an
    all-reduce sum over the default process group.

    (Qinghao): Weighted loss based on num_active_elements.
    To achieve accurate sequence parallel loss calculation, we need to get
    the real active_elements of each sequence partition.
    For data parallelism, the loss almost remains the same (also more accurate).

    Requires an initialized default process group. If no rank has any active label the
    result is 0/0 (NaN).
    """
    shift_labels = labels[..., 1:].contiguous()
    shift_labels = shift_labels.view(-1)
    padding_mask = shift_labels.eq(ignore_index)  # IGNORE_INDEX = -100 by default
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    global_active_sum = num_active_elements.detach().clone()
    dist.all_reduce(global_active_sum)
    loss_weight = num_active_elements / global_active_sum * dist.get_world_size()
    return loss_weight


def reshard_hiddne_states_and_labels(hidden_states, labels):
    """Redistribute supervised tokens evenly across the Ulysses sequence-parallel group.

    Each SP rank holds a contiguous sequence shard: ``labels`` of shape
    ``(bs, shard_seqlen)`` and ``hidden_states`` of shape ``(bs, shard_seqlen, hidden)``
    (shard lengths may differ between ranks). The shards are gathered, shifted for
    next-token prediction (hidden ``[:, :-1]`` aligned with labels ``[:, 1:]``), flattened,
    filtered to labels != IGNORE_INDEX, and split into ``sp_degree`` near-equal chunks
    (the last rank takes the remainder). If there are fewer effective labels than
    ``sp_degree``, the effective tokens are repeated so every rank receives some.

    Returns:
        ``(local_hidden_states, local_labels)`` of shapes ``(n_local, hidden)`` and ``(n_local,)``.

    Raises:
        ValueError: if no label in the whole SP group differs from IGNORE_INDEX.
    """
    from llava.constants import IGNORE_INDEX

    pg_manager = get_pg_manager()
    sp_degree = pg_manager.sp_degree
    sp_rank = pg_manager.sp_rank
    sp_group = pg_manager.ulysses_pg

    bs, shard_seqlen = labels.shape
    device = labels.device

    # Exchange the (possibly uneven) shard lengths of every SP rank. The send tensor has
    # shape (1,) to match the receive buffers exactly.
    ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(sp_degree)]
    dist.barrier(group=sp_group)
    dist.all_gather(
        ulysses_seq_len, torch.tensor([shard_seqlen], dtype=torch.int64, device=device), group=sp_group
    )
    dist.barrier(group=sp_group)
    seq_lens = torch.cat(ulysses_seq_len, dim=0).tolist()  # plain ints, one per SP rank
    total_seq_len = sum(seq_lens)

    # Gather all labels and flatten them (shifted by one for next-token prediction).
    all_labels = [torch.zeros(bs, seq_len, dtype=labels.dtype, device=device) for seq_len in seq_lens]
    dist.all_gather(all_labels, labels.contiguous(), group=sp_group)
    flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].contiguous().view(-1)

    # Positions whose label is supervised (label != IGNORE_INDEX).
    flatten_effective_label_index = flatten_global_labels.ne(IGNORE_INDEX).nonzero(as_tuple=True)
    num_effective = flatten_effective_label_index[0].shape[0]
    if num_effective == 0:
        # Previously this surfaced as a ZeroDivisionError in the repeat computation below.
        raise ValueError(
            f"No label differs from IGNORE_INDEX ({IGNORE_INDEX}) across the sequence-parallel group; "
            "cannot reshard hidden states and labels."
        )
    # Pad (by repetition) if there are fewer effective labels than SP ranks.
    if num_effective < sp_degree:
        warnings.warn(
            f"The effective label length {num_effective} is smaller than sp_degree {sp_degree}, padding the index"
        )
        repeat_num = sp_degree // num_effective + 1
    else:
        repeat_num = 1

    # Reconstruct the labels by selecting from the global labels.
    effective_global_labels = flatten_global_labels[flatten_effective_label_index]
    if repeat_num > 1:
        effective_global_labels = effective_global_labels.repeat(repeat_num)

    # Global effective sequence length and this rank's slice of it; the last rank
    # also takes the remainder of the integer division.
    global_effective_seq_len = effective_global_labels.shape[0]
    reshard_size = global_effective_seq_len // sp_degree
    start_id = reshard_size * sp_rank
    end_id = global_effective_seq_len if sp_rank == sp_degree - 1 else reshard_size * (sp_rank + 1)

    # Where this rank's original shard lives in the concatenated (un-shifted) sequence.
    original_start_id = sum(seq_lens[:sp_rank])
    original_end_id = original_start_id + seq_lens[sp_rank]

    # Get the local labels.
    effective_local_labels = torch.narrow(effective_global_labels, 0, start_id, end_id - start_id)

    # Gather all hidden states: each rank writes its shard into a zero buffer of the full
    # length, and the sum over the group reconstructs the full sequence.
    # NOTE(review): dist.all_reduce is not recorded by autograd, so gradients reach
    # `hidden_states` only through the local slice written here. Confirm this is the
    # intended backward behaviour for tokens that end up on another rank.
    hidden_size = hidden_states.shape[-1]
    all_hidden_states = torch.zeros(
        bs, total_seq_len, hidden_size, dtype=hidden_states.dtype, device=hidden_states.device
    )
    all_hidden_states[:, original_start_id:original_end_id, :] += hidden_states
    dist.barrier(group=sp_group)
    dist.all_reduce(all_hidden_states, group=sp_group)
    dist.barrier(group=sp_group)
    flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_size)

    # Get the local hidden states, using the same selection/repetition as the labels.
    effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
    if repeat_num > 1:
        effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
    effective_local_hidden_states = torch.narrow(
        effective_flatten_global_hidden_states, 0, start_id, end_id - start_id
    )

    return effective_local_hidden_states, effective_local_labels


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
reshard_hidden_states_and_labels = reshard_hiddne_states_and_labels


def sp_loss_rescale(shift_labels, loss):
    """Weight ``loss`` by this rank's share of supervised tokens and sum it over the SP group.

    Computes ``loss * local_active / group_active`` where "active" counts
    ``shift_labels != IGNORE_INDEX``, then all-reduces (sums) the result in place over the
    Ulysses sequence-parallel process group and returns it.
    NOTE(review): the in-place all_reduce is not tracked by autograd — confirm intended.
    """
    from llava.constants import IGNORE_INDEX

    sp_group = get_ulysses_sp_pg()
    labels_mask = shift_labels.ne(IGNORE_INDEX)  # IGNORE_INDEX = -100 by default
    num_active_elements = torch.sum(labels_mask)
    global_active_sum = num_active_elements.detach().clone()
    dist.all_reduce(global_active_sum, group=sp_group)
    loss = loss * num_active_elements / global_active_sum
    dist.all_reduce(loss, group=sp_group)
    return loss