import os import torch import datasets import math import random import torch.nn as nn from typing import Any, Dict, List, Optional, Union from torch.utils.data import Sampler from torch.utils.data import Sampler, Dataset from transformers import Trainer from transformers.trainer import ( is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, ) from transformers.trainer import Trainer, is_datasets_available from typing import List, Optional from .modeling_utils import evaluate_generation, evaluate_perplexity def get_vision_tower_state_maybe_zero_3(named_params, keys_to_match=['']): to_return = {k: t for k, t in named_params if any( key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: print(name, 'no ignore status') with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} return to_return def split_to_even_chunks(indices, lengths, num_chunks): """ Split a list of indices into `chunks` chunks of roughly equal lengths. """ if len(indices) % num_chunks != 0: return [indices[i::num_chunks] for i in range(num_chunks)] num_indices_per_chunk = len(indices) // num_chunks chunks = [[] for _ in range(num_chunks)] chunks_lengths = [0 for _ in range(num_chunks)] for index in indices: shortest_chunk = chunks_lengths.index(min(chunks_lengths)) chunks[shortest_chunk].append(index) chunks_lengths[shortest_chunk] += lengths[index] if len(chunks[shortest_chunk]) == num_indices_per_chunk: chunks_lengths[shortest_chunk] = float("inf") return chunks def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. assert all(l != 0 for l in lengths), "Should not have zero length." if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): # all samples are in the same modality return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] megabatch_size = world_size * batch_size mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] last_mm = mm_megabatches[-1] last_lang = lang_megabatches[-1] additional_batch = last_mm + last_lang megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] megabatch_indices = torch.randperm(len(megabatches), generator=generator) megabatches = [megabatches[i] for i in megabatch_indices] if len(additional_batch) > 0: megabatches.append(sorted(additional_batch)) return [i for megabatch in megabatches for i in megabatch] def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. indices = torch.randperm(len(lengths), generator=generator) megabatch_size = world_size * batch_size megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] return [i for megabatch in megabatches for batch in megabatch for i in batch] class LengthGroupedSampler(Sampler): r""" Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while keeping a bit of randomness. """ def __init__( self, batch_size: int, world_size: int, lengths: Optional[List[int]] = None, generator=None, group_by_modality: bool = False, ): if lengths is None: raise ValueError("Lengths must be provided.") self.batch_size = batch_size self.world_size = world_size self.lengths = lengths self.generator = generator self.group_by_modality = group_by_modality def __len__(self): return len(self.lengths) def __iter__(self): if self.group_by_modality: indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) else: indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) return iter(indices) class StrideGroupedSampler(Sampler): """Group """ def __init__( self, batch_size: int, window: int, stride: int, group: str, sort: Optional[str] = None, dataset: Optional[Dataset] = None, lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None ): if dataset is None and lengths is None: raise ValueError("One of dataset and lengths must be provided.") if group is None: raise ValueError("Group cannot be None!") if lengths is None: model_input_name = model_input_name if model_input_name is not None else "input_ids" if ( not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) or model_input_name not in dataset[0] ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " f"'{model_input_name}' key." ) lengths = [len(feature[model_input_name]) for feature in dataset] elif isinstance(lengths, torch.Tensor): logger.info( "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..." ) lengths = lengths.tolist() indices = list(range(len(lengths))) # get number of strides for each data num_strides = [] for length in lengths: num_stride = math.ceil((length - window) / stride) + 1 num_strides.append(num_stride) indice_stride_pairs = list(zip(indices, num_strides)) # NOTE: shuffle the indices in advance, otherwise the randomness may be lost when all num_strides are equal random.shuffle(indice_stride_pairs) # sort data according to the number of strides indice_stride_pairs = sorted(indice_stride_pairs, key=lambda x: x[1]) # group data instances with the same number of strides into the same batch batches = [] batch = [] prev_num_stride = None for index, num_stride in indice_stride_pairs: if num_stride != prev_num_stride: # in strict mode, all instances in the batch are forced to have the same number of strides if group == "strict": batch.clear() elif group == "relaxed": pass else: raise ValueError(f"Group method {group} must be in None, strict, relaxed!") batch.append(index) prev_num_stride = num_stride if len(batch) == batch_size: batches.append((batch.copy(), num_stride)) batch.clear() if len(batch) and group == "relaxed": batches.append((batch.copy(), num_stride)) if sort is None: random.shuffle(batches) elif sort == "ascend": batches = sorted(batches, key=lambda x: x[1]) elif sort == "descend": batches = sorted(batches, key=lambda x: x[1], reverse=True) else: raise ValueError(f"Sort method {sort} must be in None, ascend, descend!") batches = [x[0] for x in batches] self.indices = sum(batches, []) def __len__(self): return len(self.indices) def __iter__(self): return iter(self.indices) class LLaVATrainer(Trainer): def __init__(self, *args, model_args, **kwargs): super().__init__(*args, **kwargs) self.model_args = model_args def compute_loss(self, model, inputs, return_outputs=False): if "retrieval_span" in inputs: self.model.memory._retrieval_span = inputs['retrieval_span'][0] inputs.pop("retrieval_span") outputs = super().compute_loss(model, inputs, return_outputs) if hasattr(self.model, "memory") and hasattr(self.model.memory, "_retrieval_span"): del self.model.memory._retrieval_span del self.model.memory._retrieval_condensing_ratios return outputs def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and handling potential state. """ inputs.pop("length", None) inputs.pop("index", None) # move to GPU inputs = self._prepare_input(inputs) # NOTE: reset memory for each individual input if hasattr(self.model, "memory"): self.model.memory.reset() return inputs def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: # Build the sampler. if self.args.group_by_stride is not None: # print(is_datasets_available(),isinstance(self.train_dataset, datasets.Dataset)) # if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): # print("yes") # lengths = self.train_dataset.modality_lengths # else: # print("no") # lengths = None lengths = self.train_dataset.modality_lengths # lengths=self.train_dataset.stored_variables # print("$$$$$$$lengths&&&&&&&&",lengths,self.train_dataset.modality_lengths) # print("@@@@@@@",self.model.memory.config) model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None return StrideGroupedSampler( # NOTE: multiply world size to get the total number of training instances across devices batch_size=self.args.train_batch_size * self.args.world_size, window=self.model.memory.config.beacon_window, stride=self.model.memory.config.beacon_stride, group=self.args.group_by_stride, sort=self.args.sort_by_stride, dataset=self.train_dataset, lengths=lengths, model_input_name=model_input_name, ) else: return super()._get_train_sampler() @torch.no_grad() def evaluate(self, eval_dataset: Dataset | None = None, ignore_keys: List[str] | None = None, metric_key_prefix: str = "eval") -> Dict[str, float]: # memory metrics - must set up as early as possible self._memory_tracker.start() if eval_dataset is None and self.eval_dataset is None: return if self.args.eval_method == "generation": labels = self.eval_dataset["labels"] self.eval_dataset = self.eval_dataset.remove_columns(["labels"]) dataloader = self.get_eval_dataloader() self.model.memory.reset() train_beacon_ratio = self.model.memory.beacon_ratio train_beacon_ratio_mix = self.model.memory.beacon_ratio_mix self.model.memory.set( beacon_ratio=self.args.eval_beacon_ratio, beacon_ratio_mix=self.args.eval_beacon_ratio_mix, ) model = self.model.eval() if self.args.eval_method == "perplexity": perplexity = evaluate_perplexity(model, dataloader, accelerator=self.accelerator) metrics = {"perplexity": perplexity} elif self.args.eval_method == "generation": indices, outputs = evaluate_generation( model, dataloader, accelerator=self.accelerator, tokenizer=self.tokenizer, ) metrics = self.compute_metrics(outputs, labels, indices=indices) else: raise NotImplementedError(f"Eval method {self.args.eval_method} not implemented!") self.model.memory.reset() self.model.memory.set( beacon_ratio=train_beacon_ratio, beacon_ratio_mix=train_beacon_ratio_mix, ) # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): if not key.startswith(f"{metric_key_prefix}_") and key != "epoch": metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) self.log(metrics) self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) self._memory_tracker.stop_and_update_metrics(metrics) return metrics def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. """ if is_sagemaker_mp_enabled(): return super().create_optimizer() opt_model = self.model if self.optimizer is None: decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] if self.args.mm_projector_lr is not None: projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] if self.args.mm_vision_tower_lr is not None: vision_tower_parameters = [ name for name, _ in opt_model.named_parameters() if "vision_tower" in name] optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, "lr": self.args.mm_vision_tower_lr, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad) ], "weight_decay": 0.0, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad) ], "weight_decay": 0.0, "lr": self.args.mm_vision_tower_lr, }, { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, "lr": self.args.mm_projector_lr, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": 0.0, "lr": self.args.mm_projector_lr, }, ] else: optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) ], "weight_decay": 0.0, }, { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, "lr": self.args.mm_projector_lr, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) ], "weight_decay": 0.0, "lr": self.args.mm_projector_lr, }, ] else: optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) ], "weight_decay": 0.0, }, ] optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") return self.optimizer def _save_checkpoint(self, model, trial, metrics=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) # Only save Adapter keys_to_match = ['mm_projector', 'vision_resampler'] if getattr(self.args, "use_im_start_end", False): keys_to_match.extend(['embed_tokens', 'embed_in']) weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) if self.args.local_rank == 0 or self.args.local_rank == -1: self.model.config.save_pretrained(output_dir) torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) else: super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): pass else: super(LLaVATrainer, self)._save(output_dir, state_dict)