# videoxl/train/llava_trainer.py
import os
import math
import random
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import datasets
from torch.utils.data import Sampler, Dataset

from transformers import Trainer, BatchEncoding
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    is_datasets_available,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)

from .modeling_utils import evaluate_generation, evaluate_perplexity

def get_vision_tower_state_maybe_zero_3(named_params, keys_to_match=['']):
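    """Collect vision-tower parameters whose names match `keys_to_match`, gathering
    ZeRO-3 partitioned weights to CPU when necessary."""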
to_return = {k: t for k, t in named_params if any(
key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu()
for k, v in to_return.items()}
return to_return
def maybe_zero_3(param, ignore_status=False, name=None):
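    """Return a full CPU copy of `param`, gathering it first if it is partitioned by DeepSpeed ZeRO-3."""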
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, 'no ignore status')
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
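    """Collect multimodal adapter parameters (e.g. the mm projector) whose names match
    `keys_to_match`, gathered from ZeRO-3 when necessary."""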
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
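    """Like `get_length_grouped_indices`, but keeps multimodal samples (positive length)
    and language-only samples (negative length) in separate megabatches."""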
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
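    """Shuffle indices, then sort each megabatch (world_size * batch_size indices) by
    length and split it into per-device chunks of roughly equal total length."""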
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
class StrideGroupedSampler(Sampler):
"""Group """
def __init__(
self,
batch_size: int,
window: int,
stride: int,
group: str,
sort: Optional[str] = None,
dataset: Optional[Dataset] = None,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None
):
if dataset is None and lengths is None:
raise ValueError("One of dataset and lengths must be provided.")
if group is None:
raise ValueError("Group cannot be None!")
if lengths is None:
model_input_name = model_input_name if model_input_name is not None else "input_ids"
if (
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
or model_input_name not in dataset[0]
):
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, StrideGroupedSampler will be slow. Converting lengths to List[int]..."
            )
lengths = lengths.tolist()
indices = list(range(len(lengths)))
# get number of strides for each data
num_strides = []
for length in lengths:
num_stride = math.ceil((length - window) / stride) + 1
num_strides.append(num_stride)
indice_stride_pairs = list(zip(indices, num_strides))
# NOTE: shuffle the indices in advance, otherwise the randomness may be lost when all num_strides are equal
random.shuffle(indice_stride_pairs)
# sort data according to the number of strides
indice_stride_pairs = sorted(indice_stride_pairs, key=lambda x: x[1])
# group data instances with the same number of strides into the same batch
batches = []
batch = []
prev_num_stride = None
for index, num_stride in indice_stride_pairs:
if num_stride != prev_num_stride:
# in strict mode, all instances in the batch are forced to have the same number of strides
if group == "strict":
batch.clear()
elif group == "relaxed":
pass
else:
raise ValueError(f"Group method {group} must be in None, strict, relaxed!")
batch.append(index)
prev_num_stride = num_stride
if len(batch) == batch_size:
batches.append((batch.copy(), num_stride))
batch.clear()
if len(batch) and group == "relaxed":
batches.append((batch.copy(), num_stride))
if sort is None:
random.shuffle(batches)
elif sort == "ascend":
batches = sorted(batches, key=lambda x: x[1])
elif sort == "descend":
batches = sorted(batches, key=lambda x: x[1], reverse=True)
else:
raise ValueError(f"Sort method {sort} must be in None, ascend, descend!")
batches = [x[0] for x in batches]
self.indices = sum(batches, [])
def __len__(self):
return len(self.indices)
def __iter__(self):
return iter(self.indices)
class LLaVATrainer(Trainer):
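    """Trainer for VideoXL / LLaVA models. Adds stride-grouped batching, beacon-memory
    handling during training and evaluation, and optional separate learning rates for
    the mm projector and vision tower."""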
def __init__(self, *args, model_args, **kwargs):
super().__init__(*args, **kwargs)
self.model_args = model_args
def compute_loss(self, model, inputs, return_outputs=False):
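        # Pass any per-sample retrieval span to the beacon memory for this forward pass,
        # and clean the temporary attributes up again after the loss is computed.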
if "retrieval_span" in inputs:
self.model.memory._retrieval_span = inputs['retrieval_span'][0]
inputs.pop("retrieval_span")
outputs = super().compute_loss(model, inputs, return_outputs)
if hasattr(self.model, "memory") and hasattr(self.model.memory, "_retrieval_span"):
del self.model.memory._retrieval_span
del self.model.memory._retrieval_condensing_ratios
return outputs
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
inputs.pop("length", None)
inputs.pop("index", None)
# move to GPU
inputs = self._prepare_input(inputs)
# NOTE: reset memory for each individual input
if hasattr(self.model, "memory"):
self.model.memory.reset()
return inputs
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# Build the sampler.
if self.args.group_by_stride is not None:
lengths = self.train_dataset.modality_lengths
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
return StrideGroupedSampler(
# NOTE: multiply world size to get the total number of training instances across devices
batch_size=self.args.train_batch_size * self.args.world_size,
window=self.model.memory.config.beacon_window,
stride=self.model.memory.config.beacon_stride,
group=self.args.group_by_stride,
sort=self.args.sort_by_stride,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)
else:
return super()._get_train_sampler()
@torch.no_grad()
    def evaluate(self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval") -> Dict[str, float]:
# memory metrics - must set up as early as possible
self._memory_tracker.start()
if eval_dataset is None and self.eval_dataset is None:
return
if self.args.eval_method == "generation":
labels = self.eval_dataset["labels"]
self.eval_dataset = self.eval_dataset.remove_columns(["labels"])
dataloader = self.get_eval_dataloader()
self.model.memory.reset()
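        # Temporarily switch the beacon memory to the evaluation ratio; the training
        # ratio is restored after evaluation.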
train_beacon_ratio = self.model.memory.beacon_ratio
train_beacon_ratio_mix = self.model.memory.beacon_ratio_mix
self.model.memory.set(
beacon_ratio=self.args.eval_beacon_ratio,
beacon_ratio_mix=self.args.eval_beacon_ratio_mix,
)
model = self.model.eval()
if self.args.eval_method == "perplexity":
perplexity = evaluate_perplexity(model, dataloader, accelerator=self.accelerator)
metrics = {"perplexity": perplexity}
elif self.args.eval_method == "generation":
indices, outputs = evaluate_generation(
model,
dataloader,
accelerator=self.accelerator,
tokenizer=self.tokenizer,
)
metrics = self.compute_metrics(outputs, labels, indices=indices)
else:
raise NotImplementedError(f"Eval method {self.args.eval_method} not implemented!")
self.model.memory.reset()
self.model.memory.set(
beacon_ratio=train_beacon_ratio,
beacon_ratio_mix=train_beacon_ratio_mix,
)
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_") and key != "epoch":
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
self.log(metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
self._memory_tracker.stop_and_update_metrics(metrics)
return metrics
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
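            # Build parameter groups so the mm projector and (optionally) the vision tower
            # can use their own learning rates, with weight decay disabled for bias/LayerNorm.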
if self.args.mm_projector_lr is not None:
projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
if self.args.mm_vision_tower_lr is not None:
vision_tower_parameters = [
name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_vision_tower_lr,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.mm_vision_tower_lr,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, 'tune_mm_mlp_adapter', False):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ['mm_projector', 'vision_resampler']
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(['embed_tokens', 'embed_in'])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
else:
super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, 'tune_mm_mlp_adapter', False):
pass
else:
super(LLaVATrainer, self)._save(output_dir, state_dict)