File size: 6,295 Bytes
5f2fd70
 
 
636399f
8df3bfc
636399f
6805517
5f2fd70
 
636399f
 
448c445
 
636399f
448c445
8df3bfc
 
448c445
 
636399f
 
25b7ae1
636399f
 
 
 
 
 
 
5f2fd70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636399f
5f2fd70
 
 
636399f
 
5f2fd70
 
 
636399f
 
25b7ae1
 
636399f
 
25b7ae1
636399f
25b7ae1
 
 
 
 
 
 
636399f
25b7ae1
636399f
 
25b7ae1
 
8df3bfc
25b7ae1
5f2fd70
636399f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f2fd70
636399f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import logging
import re

app = FastAPI()

# Enable CORS if needed
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

####################################
# Text Generation Endpoint
####################################

TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)

general_prompt_template = """
أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
1. واضحًا وسهل الفهم.
2. مناسبًا للمستوى التعليمي المحدد.
3. مرتبطًا بالمادة التعليمية المطلوبة.
4. قصيرًا ومباشرًا.
### أمثلة:
1. المادة: العلوم
   المستوى: الابتدائي
   النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
2. المادة: التاريخ
   المستوى: المتوسط
   النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
3. المادة: الجغرافيا
   المستوى: المتوسط
   النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
---
المادة: {المادة}
المستوى: {المستوى}
اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
"""

class GenerateTextRequest(BaseModel):
    المادة: str
    المستوى: str

@app.post("/generate-text")
def generate_text(request: GenerateTextRequest):
    المادة = request.المادة
    المستوى = request.المستوى

    if not المادة or not المستوى:
        raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")

    try:
        prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
        inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        with torch.no_grad():
            outputs = text_model.generate(
                inputs.input_ids,
                max_length=300,
                num_return_sequences=1,
                temperature=0.1,
                top_p=0.9,
                do_sample=True,
            )
        generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove the prompt from the generated text
        generated_text = generated_text.replace(prompt, "").strip()
        logger.info(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}
    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")

####################################
# Question & Answer Generation Endpoint
####################################

QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)

def extract_answer(context: str) -> str:
    """Extract the first sentence (or a key phrase) from the context."""
    sentences = re.split(r'[.!؟]', context)
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences[0] if sentences else context

def get_question(context: str, answer: str) -> str:
    """Generate a question based on the context and the candidate answer."""
    text = "النص: " + context + " " + "الإجابة: " + answer + " </s>"
    text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
    qa_model.eval()
    generated_ids = qa_model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=64,
        num_beams=5,
        num_return_sequences=1
    )
    question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # Optionally remove a leading phrase if present
    question = question.replace('question: ', '').strip()
    return question

def generate_question_answer(context: str):
    answer = extract_answer(context)
    question = get_question(context, answer)
    return question, answer

class GenerateQARequest(BaseModel):
    text: str

@app.post("/generate-qa")
def generate_qa(request: GenerateQARequest):
    context = request.text
    if not context:
        raise HTTPException(status_code=400, detail="Text is required.")
    try:
        question, answer = generate_question_answer(context)
        logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
        return {"question": question, "answer": answer}
    except Exception as e:
        logger.error(f"Error during QA generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")

@app.get("/")
def read_root():
    return {"message": "Welcome to the Arabic Text Generation API!"}