Nurisslam committed on
Commit
f5ea7e7
·
verified ·
1 Parent(s): c930d5c

Rename inference.py to train.py

Browse files
Files changed (2) hide show
  1. inference.py +0 -13
  2. train.py +31 -0
inference.py DELETED
@@ -1,13 +0,0 @@
1
- from transformers import MT5ForConditionalGeneration, MT5Tokenizer
2
-
3
- model = MT5ForConditionalGeneration.from_pretrained("./model")
4
- tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
5
-
6
- def ask(question, context):
7
- input_text = f"Сұрақ: {question} Контекст: {context}"
8
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids
9
- output = model.generate(input_ids, max_length=100)
10
- return tokenizer.decode(output[0], skip_special_tokens=True)
11
-
12
- context = """Мәліметтер қоры дегеніміз – белгілі бір сипаттамасы бар, өзара байланыса сақталатын ақпараттар жиынтығы."""
13
- print(ask("Мәліметтер қоры дегеніміз не?", context))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments
3
+
4
+ model_name = "ai4bharat/indic-bert"
5
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
7
+
8
+ dataset = load_dataset("json", data_files="qa_dataset.json")
9
+
10
+ def preprocess(examples):
11
+ inputs = tokenizer(examples['question'], examples['context'], truncation=True, padding='max_length')
12
+ return inputs
13
+
14
+ dataset = dataset.map(preprocess, batched=True)
15
+
16
+ training_args = TrainingArguments(
17
+ output_dir="./model",
18
+ evaluation_strategy="no",
19
+ per_device_train_batch_size=4,
20
+ num_train_epochs=3
21
+ )
22
+
23
+ trainer = Trainer(
24
+ model=model,
25
+ args=training_args,
26
+ train_dataset=dataset['train']
27
+ )
28
+
29
+ trainer.train()
30
+ model.save_pretrained("./model")
31
+ tokenizer.save_pretrained("./model")