# Copyright (c) 2025 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. # LICENSE is in incl_licenses directory. # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import logging import math import os import warnings from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence import torch import transformers from torch.utils.data import Dataset from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, LlamaForCausalLM, set_seed from transformers.modeling_utils import unwrap_model import llava.data.dataset as dataset import llava.data.datasets_mixture as datasets_mixture from llava import conversation as conversation_lib from llava.constants import ( DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX, ) from llava.data import make_supervised_data_module from llava.mm_utils import process_image from llava.model import LlavaLlamaConfig, LlavaLlamaModel from llava.train.args import DataArguments, ModelArguments, TrainingArguments from llava.train.callbacks.autoresume_callback import AutoResumeCallback from llava.train.llava_trainer import LLaVATrainer, VILADPOTrainer from llava.train.sequence_parallel import set_pg_manager from llava.train.slurm_utils import TimeoutTerminateCallback from llava.train.utils import ( get_checkpoint_path, mprint, prepare_config_for_training, unit_test_rope_scaling, vision_resolution_elevation, ) from llava.trl.trainer.utils import DPODataCollatorWithPadding local_rank = None if "WANDB_PROJECT" not in os.environ: os.environ["WANDB_PROJECT"] = "AF3" def get_nb_trainable_parameters(model) -> tuple[int, int]: r""" Returns the number of trainable parameters and the number of all parameters in the model. """ trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): num_params = param.numel() # if using DS Zero 3 and the weights are initialized empty if num_params == 0 and hasattr(param, "ds_numel"): num_params = param.ds_numel # Due to the design of 4bit linear layers from bitsandbytes # one needs to multiply the number of parameters by 2 to get # the correct number of parameters if param.__class__.__name__ == "Params4bit": if hasattr(param, "element_size"): num_bytes = param.element_size() elif not hasattr(param, "quant_storage"): num_bytes = 1 else: num_bytes = param.quant_storage.itemsize num_params = num_params * 2 * num_bytes all_param += num_params if param.requires_grad: trainable_params += num_params return trainable_params, all_param def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param # Borrowed from peft.utils.get_peft_model_state_dict def get_peft_state_maybe_zero_3(named_params, bias): if bias == "none": to_return = {k: t for k, t in named_params if "lora_" in k} elif bias == "all": to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} elif bias == "lora_only": to_return = {} maybe_lora_bias = {} lora_bias_names = set() for k, t in named_params: if "lora_" in k: to_return[k] = t bias_name = k.split("lora_")[0] + "bias" lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t for k, t in maybe_lora_bias: if bias_name in lora_bias_names: to_return[bias_name] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} return to_return def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def find_all_linear_names(model, lora_llm, lora_vt): cls = torch.nn.Linear lora_module_names = set() multimodal_keywords = ["mm_projector", "vision_resampler"] assert lora_llm or lora_vt, "Not applying LoRA to any of the modules..." if not lora_llm: multimodal_keywords += ["llm"] if not lora_vt: multimodal_keywords += ["vision_tower"] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if isinstance(module, cls): if not "lm_head" in name: lora_module_names.add(name) # names = name.split(".") # lora_module_names.add(names[0] if len(names) == 1 else names[-1]) # if "lm_head" in lora_module_names: # needed for 16-bit # lora_module_names.remove("lm_head") return list(lora_module_names) def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" if trainer.deepspeed: torch.cuda.synchronize() trainer.save_model(output_dir, _internal_call=True) return state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. """ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def make_conv(prompt, answer): return [ { "from": "human", "value": prompt, }, { "from": "gpt", "value": answer, }, ] @dataclass class DPODataCollator(DPODataCollatorWithPadding): tokenizer: Any = None def collate(self, batch): # first, pad everything to the same length # input_ids, labels = tuple([instance[key] for instance in instances] # for key in ("input_ids", "labels")) # input_ids = torch.nn.utils.rnn.pad_sequence( # input_ids, # batch_first=True, # padding_value=self.tokenizer.pad_token_id) # labels = torch.nn.utils.rnn.pad_sequence(labels, # batch_first=True, # padding_value=IGNORE_INDEX) # input_ids = input_ids[:, :self.tokenizer.model_max_length] # labels = labels[:, :self.tokenizer.model_max_length] # batch = dict( # input_ids=input_ids, # labels=labels, # attention_mask=input_ids.ne(self.tokenizer.pad_token_id), # ) padded_batch = {} for k in batch[0].keys(): if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): # if "prompt" in k: # to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] # else: to_pad = [torch.LongTensor(ex[k]) for ex in batch] if k.endswith("_input_ids"): padding_value = self.pad_token_id elif k.endswith("_labels"): padding_value = self.label_pad_token_id else: continue # elif k.endswith("_attention_mask"): # padding_value = self.padding_value # else: # raise ValueError(f"Unexpected key in batch '{k}'") padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value) # for the prompt, flip back so padding is on left side # if "prompt" in k: # padded_batch[k] = padded_batch[k].flip(dims=[1]) else: padded_batch[k] = [ex[k] for ex in batch] for k in ["chosen_input_ids", "rejected_input_ids"]: attn_k = k.replace("input_ids", "attention_mask") padded_batch[attn_k] = padded_batch[k].ne(self.pad_token_id) return padded_batch def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str) -> Dict: """Tokenize a single batch element. At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, we truncate the chosen/rejected. We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. """ # import pdb; pdb.set_trace() batch = {} chosen_sources = make_conv(prompt, chosen) rejected_sources = make_conv(prompt, rejected) chosen_data_dict = dataset.preprocess([chosen_sources], self.tokenizer, has_image=True) # chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id) rejected_data_dict = dataset.preprocess([rejected_sources], self.tokenizer, has_image=True) # rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id) chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()} rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()} for k, toks in { "chosen": chosen_data_dict, "rejected": rejected_data_dict, }.items(): for type_key, tokens in toks.items(): if type_key == "token_type_ids": continue batch[f"{k}_{type_key}"] = tokens return batch def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: tokenized_batch = [] Xs, keys = [], [] for feature in features: prompt = feature["prompt"] chosen = feature["chosen"] rejected = feature["rejected"] batch_element = self.tokenize_batch_element(prompt, chosen, rejected) batch_element["images"] = feature["images"] tokenized_batch.append(batch_element) # return collated batch padded_batch = self.collate(tokenized_batch) return padded_batch import json def load_jsonl(save_path): with open(save_path) as f: data = [json.loads(line) for line in f.readlines()] return data def load_json(path): with open(path) as f: data = json.load(f) return data def load_data(data_path): if "jsonl" in data_path: data_list = load_jsonl(data_path) else: data_list = load_json(data_path) return data_list class DPODataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_mixture: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): super(Dataset, self).__init__() data_path = datasets_mixture.DATASETS_LEGACY[data_mixture].data_path list_data_dict = load_data(data_path) # if data_args.num_sample is not None: # list_data_dict = list_data_dict[:data_args.num_sample] print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.list_data_dict = list_data_dict self.data_args = data_args self.image_folder = datasets_mixture.DATASETS_LEGACY[data_mixture].image_path def __len__(self): # return 20 return len(self.list_data_dict) @property def lengths(self): length_list = [] for sample in self.list_data_dict: img_tokens = 128 if "image" in sample else 0 length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) return length_list def __getitem__(self, i) -> Dict[str, torch.Tensor]: """ { 'prompt': 'Is there a snowman wearing a green scarf and hat in the background?', 'chosen': 'No, there is no snowman wearing a green scarf and hat in the background of the image. The image features a person ...', 'rejected': 'No, there is no snowman in the background.', 'image_path': '/mnt/bn/liangkeg/data/ruohongz/dpo_data/dpo_images/LRVInstruction-000000009569.jpg', 'image_name': 'LRVInstruction-000000009569.jpg' } """ # sources = self.list_data_dict[i] # if isinstance(i, int): # sources = [sources] # assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME data_dict = copy.deepcopy(self.list_data_dict[i]) # inplace modification following video_file = data_dict["video"] + ".mp4" video_folder = self.image_folder video_path = os.path.join(video_folder, video_file) num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8 loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0 fps = None frame_count = None images, frames_loaded = dataset.LazySupervisedDataset._load_video( video_path, num_video_frames, loader_fps, self.data_args, fps=fps, frame_count=frame_count ) image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images]) image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images]) data_dict["images"] = image_tensor prompt = data_dict["prompt"] prompt = prompt.replace("