Nurisslam committed on
Commit
aea6f00
·
verified ·
1 Parent(s): f35f5a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -34
app.py CHANGED
@@ -1,39 +1,27 @@
1
- from transformers import MT5ForConditionalGeneration, MT5Tokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments
2
- from datasets import load_dataset, Dataset
3
- import json
4
 
5
- # Загрузка данных
6
- with open("data/dataset_qa.json") as f:
7
- data = json.load(f)
8
- dataset = Dataset.from_list(data)
9
 
10
- model_name = "google/mt5-small"
11
- tokenizer = MT5Tokenizer.from_pretrained(model_name)
12
- model = MT5ForConditionalGeneration.from_pretrained(model_name)
 
 
 
13
 
14
def preprocess(example):
    """Tokenize one QA example into model inputs and training labels.

    The prompt is built as "Сұрақ: <question> Контекст: <context>" and
    tokenized to a fixed length of 512; the answer is tokenized to a fixed
    length of 128 and stored under ``"labels"``.

    Args:
        example: Mapping with string values under the keys ``"question"``,
            ``"context"`` and ``"answer"``.

    Returns:
        The tokenizer output for the prompt, with an added ``"labels"`` list
        in which every padding position is set to -100.
    """
    input_text = "Сұрақ: " + example["question"] + " Контекст: " + example["context"]
    target_text = example["answer"]
    inputs = tokenizer(input_text, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(target_text, max_length=128, truncation=True, padding="max_length")
    # Padding must not contribute to the loss: -100 is the label value the
    # model's cross-entropy loss ignores, so swap every pad id for it.
    pad_id = tokenizer.pad_token_id
    inputs["labels"] = [tok if tok != pad_id else -100 for tok in labels["input_ids"]]
    return inputs
21
-
22
- tokenized_dataset = dataset.map(preprocess)
23
-
24
- training_args = Seq2SeqTrainingArguments(
25
- output_dir="./model",
26
- evaluation_strategy="epoch",
27
- learning_rate=2e-5,
28
- per_device_train_batch_size=4,
29
- num_train_epochs=5,
30
- save_total_limit=1,
31
- )
32
-
33
- trainer = Seq2SeqTrainer(
34
- model=model,
35
- args=training_args,
36
- train_dataset=tokenized_dataset,
37
  )
38
 
39
- trainer.train()
 
1
+ import gradio as gr
2
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
 
3
 
4
+ # Модельді жүктеу
5
+ model = MT5ForConditionalGeneration.from_pretrained("model")
6
+ tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
 
7
 
8
+ # Сұраққа жауап беру функциясы
9
def answer_question(question, context):
    """Generate an answer to *question* from *context* with the loaded model.

    Args:
        question: The user's question text.
        context: The passage the answer should be drawn from.

    Returns:
        The decoded model output with special tokens stripped.
    """
    input_text = f"Сұрақ: {question} Контекст: {context}"
    # Cap the prompt at 512 tokens so a long pasted context cannot grow the
    # model input without bound.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    output_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=128,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
14
 
15
+ # Интерфейс
16
+ iface = gr.Interface(
17
+ fn=answer_question,
18
+ inputs=[
19
+ gr.Textbox(label="Сұрақ жазыңыз"),
20
+ gr.Textbox(label="Контекст (мәтін үзіндісі)")
21
+ ],
22
+ outputs="text",
23
+ title="Қазақша AI-ассистент",
24
+ description="Сұрақ қойып, контекст бойынша жауап алыңыз (Мәліметтер қоры тақырыбы)"
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
 
27
+ iface.launch()