import os
import json
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple

from transformers import DataCollatorWithPadding, Seq2SeqTrainingArguments
from transformers.trainer import TRAINING_ARGS_NAME, TRAINER_STATE_NAME
from transformers.tokenization_utils import PreTrainedTokenizer

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits

from .config import FinetuningArguments
from .other import (
    AverageMeter,
    get_logger,
    save_trainable_params,
    save_valuehead_params,
    get_logits_processor,
    FINETUNING_ARGS_NAME
)


logger = get_logger(__name__)


def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    r"""
    Switches between the default (policy) model and the reward model by swapping the active LoRA adapter
    and the value head weights.
    """
    if target == "reward": # save the original (default) head weights temporarily
        valuehead_state_dict = model.v_head.state_dict()
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


def compute_rewards(
    input_ids: torch.Tensor, # (batch size x seq len) with format `X [gMASK] [BOS] Y [EOS] [PAD] ... [PAD]`
    model: AutoModelForCausalLMWithValueHead,
    tokenizer: PreTrainedTokenizer
) -> torch.Tensor:
    r"""
    Computes reward scores for `input_ids` using the reward adapter and its value head.
    """
    replace_model(model, target="reward")
    _, _, values = model(input_ids=input_ids)
    values = values.transpose(0, 1)
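    # `values` is now (batch_size, seq_len); the reward of each sequence is the value predicted
    # at its [EOS] position (falling back to the last position if no [EOS] is found)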
    rewards = []
    for i in range(input_ids.size(0)):
        eos_idx = (input_ids[i] == tokenizer.eos_token_id).nonzero() # Note: checking with [EOS] token is unsafe
        if len(eos_idx):
            eos_idx = eos_idx[0].item()
        else:
            eos_idx = input_ids.size(1) - 1
        rewards.append(values[i][eos_idx])
    rewards = torch.stack(rewards, dim=0)

    replace_model(model, target="default")
    return rewards


def cast_layernorm_dtype(
    model: AutoModelForCausalLMWithValueHead,
    layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
    r"""
    Casts the layer norm parameters to float16, or restores the saved float32 copies
    when `layer_norm_params` is given.
    """
    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name] # restore float32 weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
                param.data = param.data.to(torch.float16)

    return model, layer_norm_state_dict


class PPODataCollatorForChatGLM(DataCollatorWithPadding):
    r"""
    Data collator for ChatGLM. It dynamically pads batched data to the same length.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        min_input_length: int,
        max_input_length: int,
        inference_mode: bool = False,
    ):
        super().__init__(tokenizer, padding=True)
        self.inference_mode = inference_mode
        if min_input_length < max_input_length:
            self.input_size = LengthSampler(min_input_length, max_input_length)
        else:
            self.input_size = lambda: max_input_length # always use max_input_length

    def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch. We adopt left-padding for PPO data.

        A length sampler truncates the inputs to variable lengths.

        ChatGLM is able to generate attention masks and position ids by itself.
        """
        if self.inference_mode:
            raise NotImplementedError

        input_ids = [torch.tensor(feature["input_ids"][:self.input_size()]).flip(0) for feature in features]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        features = {"input_ids": input_ids.flip(-1)}
        return features


class PPOTrainerForChatGLM(PPOTrainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state = {"log_history": []}
        self.training_args = training_args
        self.finetuning_args = finetuning_args

    def generate(
        self,
        query_tensor: torch.Tensor, # (batch size x seq len)
        length_sampler: Optional[Callable] = None,
        return_prompt: bool = True,
        **generation_kwargs,
    ) -> torch.Tensor:
        r"""
        Generates responses with the model given the query tensor.

        Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387
        """
        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
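        # layer norm weights are kept in float32 for stability; cast them to float16 for generation
        # and restore the float32 copies right after generating (see cast_layernorm_dtype)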
        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model = self.accelerator.unwrap_model(self.model)
        response = unwrapped_model.generate(input_ids=query_tensor, **generation_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
        if unwrapped_model.pretrained_model.generation_config._from_model_config:
            unwrapped_model.pretrained_model.generation_config._from_model_config = False

        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params=layer_norm_params)

        if not return_prompt and not self.is_encoder_decoder:
            return response[:, query_tensor.size(1):]
        return response

    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
        input_ids = []
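        # move the left padding of each query behind its response so the whole sequence is right-padded,
        # e.g. query [PAD, PAD, x1, x2] + response [y1, y2, PAD] -> [x1, x2, y1, y2, PAD, PAD, PAD]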
        for query, response in zip(queries, responses): # query is left-padded, response is right-padded
            start = (query != self.tokenizer.pad_token_id).nonzero()[0].item()
            input_ids.append(torch.cat((query[start:], response, query[:start]))) # change to right-padding
        model_inputs = {"input_ids": torch.stack(input_ids, dim=0).to(self.current_device)} # already padded to equal length
        model_inputs["attention_mask"] = torch.ones_like(model_inputs["input_ids"]) # not actually used; kept to avoid errors in distributed training
        return model_inputs

    def batched_forward_pass(
        self,
        model: AutoModelForCausalLMWithValueHead,
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
    ):
        r"""
        Calculate model outputs in multiple batches.

        Override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(bs // fbs):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            input_ids = input_kwargs["input_ids"]
            logits, _, values = model(input_ids=input_ids) # chatglm only needs input_ids
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            values = values.transpose(0, 1)
            masks = torch.zeros_like(input_ids)
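            # only the response span from [BOS] up to (but excluding) [EOS] gets a mask of 1 and thus
            # contributes to the PPO loss; prompt and padding positions keep a mask of 0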
            for j in range(fbs):
                start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item() # always contain a [BOS] token
                end = (input_ids[j] == self.tokenizer.eos_token_id).nonzero() # Note: checking with [EOS] token is unsafe
                if len(end):
                    end = end[0].item()
                else:
                    end = masks.size(1)
                masks[j][start:end] = 1
                if end - start < 2:
                    raise ValueError("Responses are too short. Make sure they are at least 2 tokens long.")

            all_logits.append(logits)
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)
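
        # logprobs are computed against the next-token targets and are one step shorter than the inputs,
        # so logits, values and masks are trimmed by one position to stay aligned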
        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1],
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def ppo_train(self, max_target_length: int) -> None:
        r"""
        Implements the training loop for the PPO stage: response generation, reward computation and PPO updates.
        """
        total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.training_args.world_size
        len_dataloader = len(self.dataloader)
        num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
        num_examples = len(self.dataset)
        num_train_epochs = self.training_args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f" Num examples = {num_examples}")
            logger.info(f" Num Epochs = {num_train_epochs}")
            logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
            logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
            logger.info(f" Total optimization steps = {max_steps}")
            logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")

        # Keyword arguments for `model.generate`
        gen_kwargs = {
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": get_logits_processor()
        }
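        # top_k=0 and top_p=1.0 disable top-k and nucleus filtering respectively, so responses are
        # sampled from the full distribution (the usual setting for PPO rollouts)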
        output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
        unwrapped_model = self.accelerator.unwrap_model(self.model)

        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()

        for step in tqdm(range(max_steps)):

            for _ in range(self.config.gradient_accumulation_steps):

                batch = next(dataiter)
                steps_trained += 1
                queries = batch["input_ids"] # left-padded sequences

                unwrapped_model.gradient_checkpointing_disable()
                unwrapped_model.config.use_cache = True

                # Get response from ChatGLM
                responses_with_queries = self.generate(queries, length_sampler=output_length_sampler, **gen_kwargs)
                responses = responses_with_queries[:, queries.size(1):].clone().detach() # right-padded sequences (remember to clone!!!)
                # batch["response"] = tokenizer.batch_decode(responses, skip_special_tokens=True) # comment to avoid decode error

                for i in range(responses_with_queries.size(0)): # change to right-padding
                    start = (responses_with_queries[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
                    responses_with_queries[i] = torch.cat((responses_with_queries[i][start:], responses_with_queries[i][:start]))
                # Compute rewards
                rewards = compute_rewards(responses_with_queries, unwrapped_model, self.tokenizer)

                # Run PPO step
                unwrapped_model.gradient_checkpointing_enable()
                unwrapped_model.config.use_cache = False

                split_into_list = lambda x: [x[i] for i in range(x.size(0))]
                stats = self.step(*map(split_into_list, [queries, responses, rewards]))

                loss_meter.update(stats["ppo/loss/total"])
                reward_meter.update(rewards.sum().item(), n=rewards.size(0))

                if steps_trained == len_dataloader:
                    dataiter = iter(self.dataloader)
                    steps_trained = 0

            if self.is_world_process_zero() and (step + 1) % self.training_args.logging_steps == 0:
                logs = {
                    "loss": round(loss_meter.avg, 4),
                    "reward": round(reward_meter.avg, 4),
                    "learning_rate": stats["ppo/learning_rate"],
                    "epoch": round(step / num_steps_per_epoch, 2)
                }
                print(logs)
                logs["step"] = step
                self.state["log_history"].append(logs)
                loss_meter.reset()
                reward_meter.reset()

            if (step + 1) % self.training_args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(self.training_args.output_dir, f"checkpoint-{step+1}"))

    def is_world_process_zero(self) -> bool:
        r"""
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        return self.training_args.process_index == 0

    def save_state(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves trainer state.
        """
        if not self.is_world_process_zero():
            return

        output_dir = output_dir if output_dir is not None else self.training_args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, TRAINER_STATE_NAME), "w", encoding="utf-8", newline="\n") as f:
            json.dump(self.state, f, indent=2)

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoints. We use `self.model.pretrained_model` to refer to the backbone model.

        Override to inject custom behavior.
        """
        if not self.is_world_process_zero():
            return

        output_dir = output_dir if output_dir is not None else self.training_args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        unwrapped_model = self.accelerator.unwrap_model(self.model)
        if hasattr(unwrapped_model.pretrained_model, "peft_config"): # peft methods
            unwrapped_model.pretrained_model.save_pretrained(output_dir) # save lora weights
        else: # non-peft methods
            save_trainable_params(output_dir, unwrapped_model.pretrained_model)

        if hasattr(unwrapped_model, "v_head"):
            save_valuehead_params(output_dir, unwrapped_model.v_head) # save valuehead weights

        torch.save(self.training_args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))
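

# Rough usage sketch (hypothetical variable names; assumes the keyword arguments of trl's
# PPOTrainer from this era -- config, model, ref_model, tokenizer, dataset, data_collator):
#
#     collator = PPODataCollatorForChatGLM(tokenizer, min_input_length=..., max_input_length=...)
#     ppo_trainer = PPOTrainerForChatGLM(
#         training_args, finetuning_args,
#         config=ppo_config, model=model, ref_model=ref_model, tokenizer=tokenizer,
#         dataset=dataset, data_collator=collator
#     )
#     ppo_trainer.ppo_train(max_target_length=max_target_length)
#     ppo_trainer.save_model()
#     ppo_trainer.save_state()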