|
import pandas as pd |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments |
|
from sklearn.model_selection import train_test_split |
|
|
|
# Load the training table; expected columns: 'query' (model input text)
# and 'arguments' (target text to generate).
# NOTE(review): path is relative — confirm the script is run from the
# directory that contains data/.
data = pd.read_csv('data/train_data.csv')

queries = data['query'].tolist()

arguments = data['arguments'].tolist()



# 80/20 train/eval split; fixed random_state makes the split reproducible.
train_queries, eval_queries, train_arguments, eval_arguments = train_test_split(queries, arguments, test_size=0.2, random_state=42)
|
|
|
# Pre-trained BART-large checkpoint used as the fine-tuning starting
# point; tokenizer and model must come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")



# Tokenize source texts. padding=True pads every example to the longest
# sequence in the whole split (no per-batch collator is used later).
train_encodings = tokenizer(train_queries, truncation=True, padding=True)

eval_encodings = tokenizer(eval_queries, truncation=True, padding=True)



# Tokenize target texts in target mode so target-side special tokens are
# applied correctly.
# NOTE(review): as_target_tokenizer() is deprecated in recent
# transformers releases in favor of tokenizer(text_target=...) — confirm
# the pinned transformers version before migrating.
with tokenizer.as_target_tokenizer():

    train_labels = tokenizer(train_arguments, truncation=True, padding=True)

    eval_labels = tokenizer(eval_arguments, truncation=True, padding=True)
|
|
|
class PlotDataset(torch.utils.data.Dataset):
    """Torch dataset pairing tokenized source encodings with tokenized labels.

    Args:
        encodings: Tokenizer output for the source texts (a BatchEncoding
            or mapping whose values are per-example token-id lists; must
            expose an ``input_ids`` field).
        labels: Tokenizer output for the target texts; its ``input_ids``
            entry is used as the label sequence.
        label_pad_token_id: Optional pad token id. When given, label tokens
            equal to this id are replaced with -100 so the seq2seq
            cross-entropy loss ignores padded positions. Defaults to None,
            which keeps the raw label ids (original behavior) — but note
            that padding then contributes to the loss.
    """

    def __init__(self, encodings, labels, label_pad_token_id=None):
        self.encodings = encodings
        self.labels = labels
        self.label_pad_token_id = label_pad_token_id

    def __getitem__(self, idx):
        # One tensor per tokenizer field (input_ids, attention_mask, ...).
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        labels = torch.tensor(self.labels['input_ids'][idx])
        if self.label_pad_token_id is not None:
            # -100 is the ignore_index used by HF seq2seq model losses,
            # so masked positions do not affect training.
            labels = labels.masked_fill(labels == self.label_pad_token_id, -100)
        item['labels'] = labels
        return item

    def __len__(self):
        # BatchEncoding exposes its fields as attributes as well as keys.
        return len(self.encodings.input_ids)
|
|
|
# Wrap the tokenized splits in torch Datasets for the Trainer.
# NOTE(review): label sequences still contain pad token ids, so padded
# positions contribute to the loss — consider DataCollatorForSeq2Seq or
# masking pads to -100; confirm this is intended.
train_dataset = PlotDataset(train_encodings, train_labels)

eval_dataset = PlotDataset(eval_encodings, eval_labels)
|
|
|
# Fine-tuning hyperparameters. Batch size 2 keeps bart-large within
# typical single-GPU memory; evaluation runs once per epoch and uses
# generation (predict_with_generate) capped at 100 tokens.
training_args = Seq2SeqTrainingArguments(

    output_dir='./results',  # checkpoints are written here

    per_device_train_batch_size=2,

    per_device_eval_batch_size=2,

    num_train_epochs=3,

    logging_dir='./logs',

    logging_steps=10,

    save_steps=500,

    save_total_limit=2,  # keep only the two most recent checkpoints

    # NOTE(review): renamed to eval_strategy in transformers >= 4.46 —
    # confirm the pinned version.
    evaluation_strategy="epoch",

    predict_with_generate=True,  # eval uses generate() instead of teacher forcing

    generation_max_length=100,

)

# Trainer handles the train/eval loop; the tokenizer is passed so it is
# saved alongside checkpoints.
trainer = Seq2SeqTrainer(

    model=model,

    args=training_args,

    train_dataset=train_dataset,

    eval_dataset=eval_dataset,

    tokenizer=tokenizer,

)
|
|
|
# Run fine-tuning (evaluates each epoch per training_args).
trainer.train()



# Persist the fine-tuned model and its tokenizer together so the output
# directory can be reloaded with from_pretrained().
trainer.save_model("fine-tuned-bart-large")

tokenizer.save_pretrained("fine-tuned-bart-large")



print("Model and tokenizer fine-tuned and saved as 'fine-tuned-bart-large'")