# Adapted from: https://www.philschmid.de/fine-tune-flan-t5#3-fine-tune-and-evaluate-flan-t5
import os
import json
import copy
import glob
import torch
import random
import warnings
import evaluate
import numpy as np

from datasets import Dataset
from dataclasses import dataclass
from transformers import (
    set_seed,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer,
    Trainer,
    Seq2SeqTrainer,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
# from peft import get_peft_model, LoraConfig, TaskType
from transformers.trainer_callback import TrainerCallback

# from dsp.modules.finetuning.fid import *

warnings.filterwarnings("ignore")

# Label value used to exclude positions (prompt tokens, padding) from the loss.
IGNORE_INDEX = -100
DEFAULT_SEP_TOKEN = "[SEP]"
DEFAULT_PAD_TOKEN = "[PAD]"
# These must be non-empty: an empty string cannot act as a distinct special
# token, and DEFAULT_UNK_TOKEN is registered through SPECIAL_TOKENS_DICT below.
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
SPECIAL_TOKENS_DICT = {
    "sep_token": DEFAULT_SEP_TOKEN,
    "pad_token": DEFAULT_PAD_TOKEN,
    # "eos_token": DEFAULT_EOS_TOKEN,
    # "bos_token": DEFAULT_BOS_TOKEN,
    "unk_token": DEFAULT_UNK_TOKEN,
}


def _freeze_model_layers(model, unfreeze_last_n):
    """Freeze *model* except its last ``unfreeze_last_n`` decoder blocks and output head.

    Assumes a GPT-2-style layout exposing ``model.transformer.h`` (list of
    blocks), ``model.transformer.ln_f`` and ``model.lm_head``. The model is
    modified in place and also returned.
    """
    # Freeze all layers
    for parameter in model.parameters():
        parameter.requires_grad = False

    # Unfreeze the last n transformer blocks in the decoder
    num_decoder_layers = len(model.transformer.h)
    for i, block in enumerate(model.transformer.h):
        if i >= num_decoder_layers - unfreeze_last_n:
            for parameter in block.parameters():
                parameter.requires_grad = True

    # Unfreeze parameters after decoder block
    for parameter in model.transformer.ln_f.parameters():
        parameter.requires_grad = True
    for parameter in model.lm_head.parameters():
        parameter.requires_grad = True

    return model


def _load_data(path):
    """Load a JSON-Lines file into a ``datasets.Dataset``.

    Each non-blank line of *path* is parsed as one JSON record. Blank lines
    (such as a trailing newline at end of file) are skipped rather than
    passed to the JSON parser, which would reject them.
    """
    import ujson

    with open(path, encoding="utf-8") as f:
        records = [ujson.loads(line) for line in f if line.strip()]

    return Dataset.from_list(records)


def preprocess_prompt(text, tokenizer, encoder_decoder_model, decoder_only_model, rationale):
    """Format a prompt for the model family.

    Encoder-decoder models get a trailing space. Decoder-only models get
    ``" " + tokenizer.sep_token`` appended to mark where the completion
    starts. ``decoder_only_model`` and ``rationale`` are unused here.
    """
    text = f'{text} ' if encoder_decoder_model else f'{text} {tokenizer.sep_token}'
    return text


def preprocess_completion(text, tokenizer, encoder_decoder_model, decoder_only_model, rationale):
    """Format a completion for the model family.

    Decoder-only models get ``tokenizer.eos_token`` appended. In both cases
    leading whitespace is stripped. ``decoder_only_model`` and ``rationale``
    are unused here.
    """
    text = text if encoder_decoder_model else f'{text}{tokenizer.eos_token}'
    return text.lstrip()


def _preprocess_data(dataset, tokenizer, encoder_decoder_model, decoder_only_model, config):
    """Apply prompt/completion formatting and drop examples with an empty completion.

    ``config['rationale']`` is forwarded to the per-field formatters.
    """
    dataset = dataset.map(lambda x: {
        "prompt": preprocess_prompt(x["prompt"], tokenizer, encoder_decoder_model, decoder_only_model, config['rationale']),
        "completion": preprocess_completion(x["completion"], tokenizer, encoder_decoder_model, decoder_only_model, config['rationale']),
    })

    # Count exactly what the filter below removes (any falsy completion),
    # reading only the completion column instead of materializing every row.
    num_skipped = sum(1 for completion in dataset["completion"] if not completion)
    print(f'# examples skipped due to parsing error: {num_skipped} / {len(dataset)}')

    dataset = dataset.filter(lambda x: x["completion"])
    return dataset


def _tokenize_dataset(dataset, tokenizer, encoder_decoder_model, decoder_only_model):
    """Tokenize prompts/completions into model inputs and ``labels``.

    Encoder-decoder: prompts become the encoder input and completions the
    labels, each padded to the longest example in the dataset.

    Decoder-only: ``prompt + " " + completion`` is tokenized as one sequence.
    Examples longer than ``tokenizer.model_max_length`` are dropped. The
    prompt positions of the labels are masked with IGNORE_INDEX, so only the
    completion is scored.

    In both cases, label positions equal to ``tokenizer.pad_token_id`` are
    set to IGNORE_INDEX.

    Raises:
        ValueError: if neither ``encoder_decoder_model`` nor ``decoder_only_model`` is set.
    """
    def get_dataset_stats(dataset, tokenizer, column):
        # Longest tokenized length of `column` across the dataset.
        tokenized_inputs = dataset.map(lambda x: tokenizer(x[column]), batched=True)
        max_length = max([len(x) for x in tokenized_inputs["input_ids"]])
        return max_length

    def get_tokens_seq2seq(sample, max_source_length, max_target_length, padding="max_length"):
        # Tokenize inputs
        model_inputs = tokenizer(sample["prompt"], max_length=max_source_length, padding=padding, truncation=True)

        # Tokenize targets
        labels = tokenizer(text_target=sample["completion"], max_length=max_target_length, padding=padding, truncation=True)
        labels = labels["input_ids"]

        # Replace all tokenizer.pad_token_id in the labels by IGNORE_INDEX when we want to ignore padding in the loss.
        if padding == "max_length":
            labels = [[(l if l != tokenizer.pad_token_id else IGNORE_INDEX) for l in label] for label in labels]

        model_inputs["labels"] = labels
        return model_inputs

    def get_tokens_causal(sample, max_length, padding="max_length"):
        # Tokenize inputs
        model_inputs = tokenizer(sample["combined"], max_length=max_length, padding=padding, truncation=True)

        # Create targets: mask the prompt prefix so only the completion is learned.
        labels = copy.deepcopy(model_inputs["input_ids"])
        prompts = tokenizer(sample["prompt"], max_length=max_length, truncation=True)
        prompt_lens = [len(tokens) for tokens in prompts["input_ids"]]

        for label, source_len in zip(labels, prompt_lens):
            label[:source_len] = [IGNORE_INDEX] * source_len

        # Replace all tokenizer.pad_token_id in the labels by IGNORE_INDEX when we want to ignore padding in the loss.
        if padding == "max_length":
            labels = [[(l if l != tokenizer.pad_token_id else IGNORE_INDEX) for l in label] for label in labels]

        model_inputs["labels"] = labels
        return model_inputs

    if encoder_decoder_model:
        max_source_length = get_dataset_stats(dataset, tokenizer, "prompt")
        max_target_length = get_dataset_stats(dataset, tokenizer, "completion")
        kwargs = {"max_source_length" : max_source_length, "max_target_length" : max_target_length}
        tokenized_dataset = dataset.map(get_tokens_seq2seq, batched=True, fn_kwargs=kwargs)

    elif decoder_only_model:
        dataset = dataset.map(lambda example: {"combined": example["prompt"] + " " + example["completion"]})
        dataset = dataset.filter(lambda x: len(tokenizer(x["combined"])["input_ids"]) <= tokenizer.model_max_length)
        max_length = get_dataset_stats(dataset, tokenizer, "combined")
        kwargs = {"max_length" : max_length}
        tokenized_dataset = dataset.map(get_tokens_causal, batched=True, fn_kwargs=kwargs)

    else:
        # Without this branch the prints below would hit unbound locals.
        raise ValueError("Model must be either encoder-decoder or decoder-only")

    print(f"Dataset statistics: {kwargs}")
    print(f"Keys of tokenized dataset: {list(tokenized_dataset.features)}")
    return tokenized_dataset


def _compute_metrics(metric, eval_preds, tokenizer):
    """Decode generated predictions and score them against the labels.

    Scores returned by *metric* are scaled by 100 and rounded to 4 decimals.
    ``gen_len`` is the mean number of non-pad tokens per prediction.
    ``use_stemmer=True`` is passed, so *metric* is expected to be ROUGE
    (as loaded in ``finetune_hf``).
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Predictions concatenated across eval batches may be padded with
    # IGNORE_INDEX, which the tokenizer cannot decode; map it to the pad id.
    preds = np.where(preds != IGNORE_INDEX, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace IGNORE_INDEX in the labels as we can't decode them.
    labels = np.where(labels != IGNORE_INDEX, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result


class PeftSavingCallback(TrainerCallback):
    """At the end of training, re-save the model and delete any full weight dump.

    The model passed in ``kwargs["model"]`` is saved with ``save_pretrained``
    into the best checkpoint directory. For a PEFT-wrapped model this
    presumably writes only the adapter weights; confirm against the PEFT
    version in use. A ``pytorch_model.bin`` in that directory is then removed.
    """

    def on_train_end(self, args, state, control, **kwargs):
        # best_model_checkpoint is None when no checkpoint was evaluated/saved;
        # fall back to the run's output directory instead of crashing.
        peft_model_path = state.best_model_checkpoint or args.output_dir
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(peft_model_path, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)


def _common_training_kwargs(config):
    """Return the TrainingArguments keyword arguments shared by both training paths."""
    return dict(
        output_dir=config['output_dir'],
        per_device_train_batch_size=config['batch_size'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        per_device_eval_batch_size=config['batch_size'],
        learning_rate=config['lr'],  # 1e-4, # 5e-5
        num_train_epochs=config['epochs'],
        # logging & evaluation strategies
        log_level="error",
        logging_dir=f"{config['output_dir']}/logs",
        logging_strategy="steps",
        logging_steps=500,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=config['epochs'],
        load_best_model_at_end=True,
        report_to="tensorboard",
        fp16=config['fp16'],
        bf16=config['bf16'],
    )


def _train_seq2seq(model, tokenizer, tokenized_dataset, metric, config):
    """Train an encoder-decoder model, evaluating generations with *metric* each epoch.

    Returns:
        The Trainer's ``best_model_checkpoint``.
    """
    # Define data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    # Define training args
    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        **_common_training_kwargs(config),
    )

    # Create trainer instance
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        data_collator=data_collator,
        compute_metrics=lambda x: _compute_metrics(metric, x, tokenizer),
        callbacks=[PeftSavingCallback] if config['peft'] else None,
    )

    trainer.train()

    return trainer.state.best_model_checkpoint


def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """
    Resize tokenizer and embedding.

    Adds *special_tokens_dict* to the tokenizer, resizes the model's token
    embeddings to the new vocabulary size, and initializes each newly added
    input/output embedding row to the mean of the pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


@dataclass
class DataCollatorForSupervisedDataset(object):
    """
    Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad id and ``labels`` with
    IGNORE_INDEX to the longest example in the batch. The attention mask
    marks every position whose id differs from the pad id.
    """

    tokenizer: PreTrainedTokenizer

    def __call__(self, instances):
        pad_token_id = self.tokenizer.pad_token_id
        # One 1-D tensor per example: pad_sequence expects a sequence of
        # tensors, and examples of different lengths cannot form one tensor.
        input_ids = [torch.tensor(instance["input_ids"]) for instance in instances]
        labels = [torch.tensor(instance["labels"]) for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(pad_token_id))


def _train_causal(model, tokenizer, tokenized_dataset, metric, config):
    """Train a decoder-only model; *metric* is accepted for signature parity but unused.

    Returns:
        The Trainer's ``best_model_checkpoint``.
    """
    # Define data collator
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    # Define training args
    training_args = TrainingArguments(**_common_training_kwargs(config))

    # Create trainer instance
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        data_collator=data_collator,
        callbacks=[PeftSavingCallback] if config['peft'] else None,
    )

    trainer.train()

    return trainer.state.best_model_checkpoint


def _load_model(target, encoder_decoder_model, config):
    """Load *target*, wrapped with LoRA if ``config['peft']`` or as FiD-T5 if ``config['fid']``."""
    AutoModelClass = AutoModelForSeq2SeqLM if encoder_decoder_model else AutoModelForCausalLM

    if config['peft']:
        # Imported here because the module-level import is disabled; without it
        # this branch would always fail with NameError.
        from peft import get_peft_model, LoraConfig, TaskType

        model = AutoModelClass.from_pretrained(target, device_map='auto')
        task_type = TaskType.SEQ_2_SEQ_LM if encoder_decoder_model else TaskType.CAUSAL_LM
        peft_config = LoraConfig(task_type=task_type, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    elif config['fid']:
        # Same reason as above; presumably FiDT5 lives in this module (the
        # disabled module-level import was `from ...fid import *`).
        from dsp.modules.finetuning.fid import FiDT5

        t5 = AutoModelClass.from_pretrained(target)
        model = FiDT5(t5.config)
        model.load_t5(t5.state_dict())
    else:
        model = AutoModelClass.from_pretrained(target)

    return model


def finetune_hf(data_path, target, config):
    """Fine-tune the Hugging Face model *target* on the JSON-Lines file *data_path*.

    If ``../finetuning_ckpts/<config['save']>`` already contains checkpoints
    from an earlier run, training is skipped. In that case the best
    checkpoint recorded in the latest checkpoint's ``trainer_state.json``
    is returned.

    Args:
        data_path: JSON-Lines file whose records have ``prompt`` and
            ``completion`` fields.
        target: model name or path passed to ``from_pretrained``.
        config: training options. Keys read include ``save``, ``fid``,
            ``peft``, ``rationale``, ``batch_size``,
            ``gradient_accumulation_steps``, ``lr``, ``epochs``, ``fp16`` and
            ``bf16``. ``target`` and ``output_dir`` are written into it when
            training runs.

    Returns:
        The best checkpoint path as recorded by the Trainer.
    """
    set_seed(42)

    output_dir = os.path.join('../finetuning_ckpts', config['save'])
    ckpts = glob.glob(f'{output_dir}/checkpoint*')

    if ckpts:
        # training completed, load best model
        final_ckpt = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))[-1]
        with open(os.path.join(final_ckpt, 'trainer_state.json'), 'r') as f:
            state = json.load(f)
        best_model_checkpoint = state['best_model_checkpoint']

    else:
        # No checkpoint yet: either a fresh run, or an earlier run stopped
        # before its first save (the directory may already exist). Train.
        os.makedirs(output_dir, exist_ok=True)
        config['target'] = target
        config['output_dir'] = output_dir

        with open(os.path.join(config['output_dir'], 'compiler_config.json'), 'w') as f:
            json.dump(config, f)

        architecture = AutoConfig.from_pretrained(target).__dict__["architectures"][0]
        encoder_decoder_model = ("ConditionalGeneration" in architecture) or ("T5WithLMHeadModel" in architecture)
        decoder_only_model = ("CausalLM" in architecture) or ("GPT2LMHeadModel" in architecture)
        assert encoder_decoder_model or decoder_only_model, f"Unknown HuggingFace model class: {target}"
        assert not config['fid'] or encoder_decoder_model, f"Model must be encoder-decoder for Fusion in Decoder"
        assert not config['fid'] or not config['peft'], f"FiD and PEFT can't be trained together"

        # load model
        model = _load_model(target, encoder_decoder_model, config)
        # model = _freeze_model_layers(model, unfreeze_last_n=2)

        # load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(target)
        if decoder_only_model:
            smart_tokenizer_and_embedding_resize(SPECIAL_TOKENS_DICT, tokenizer, model)

        # load data
        dataset = _load_data(data_path)
        dataset = _preprocess_data(dataset, tokenizer, encoder_decoder_model, decoder_only_model, config)
        tokenized_dataset = _tokenize_dataset(dataset, tokenizer, encoder_decoder_model, decoder_only_model)
        tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.1)
        print(f'Finetuning dataset: {tokenized_dataset}')

        # start training
        metric = evaluate.load("rouge")
        if encoder_decoder_model:
            best_model_checkpoint = _train_seq2seq(model, tokenizer, tokenized_dataset, metric, config)
        elif decoder_only_model:
            best_model_checkpoint = _train_causal(model, tokenizer, tokenized_dataset, metric, config)

    print(f'Best checkpoint of model: {best_model_checkpoint}')
    return best_model_checkpoint