Nurisslam committed (verified)
Commit 7cc244b · Parent: 1fa8d40

Create train.py

Files changed (1): train.py (+39, -0)
train.py ADDED
@@ -0,0 +1,39 @@
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments
+ from datasets import Dataset
+ import json
+
+ # Load the QA data
+ with open("data/dataset_qa.json") as f:
+     data = json.load(f)
+ dataset = Dataset.from_list(data)
+
+ model_name = "google/mt5-small"
+ tokenizer = MT5Tokenizer.from_pretrained(model_name)
+ model = MT5ForConditionalGeneration.from_pretrained(model_name)
+
+ def preprocess(example):
+     # Build the "question + context" prompt and tokenize input and target
+     input_text = "Сұрақ: " + example["question"] + " Контекст: " + example["context"]
+     target_text = example["answer"]
+     inputs = tokenizer(input_text, max_length=512, truncation=True, padding="max_length")
+     labels = tokenizer(target_text, max_length=128, truncation=True, padding="max_length")
+     # Mask padding tokens in the labels with -100 so they are ignored by the loss
+     inputs["labels"] = [
+         token_id if token_id != tokenizer.pad_token_id else -100
+         for token_id in labels["input_ids"]
+     ]
+     return inputs
+
+ tokenized_dataset = dataset.map(preprocess)
+
+ # No eval_dataset is provided, so evaluation is left disabled
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./model",
+     learning_rate=2e-5,
+     per_device_train_batch_size=4,
+     num_train_epochs=5,
+     save_total_limit=1,
+ )
+
+ trainer = Seq2SeqTrainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset,
+ )
+
+ trainer.train()
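
The preprocessing step assumes data/dataset_qa.json holds a list of records with "question", "context", and "answer" string fields (the keys read inside preprocess). A minimal sketch of a compatible file, using placeholder values rather than real data:

import json

# Hypothetical example records; the "…" placeholders stand in for the actual
# question/context/answer text expected by preprocess.
sample = [
    {"question": "…", "context": "…", "answer": "…"},
    {"question": "…", "context": "…", "answer": "…"},
]

with open("data/dataset_qa.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)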