Murtaza249 commited on
Commit
4104208
·
verified ·
1 Parent(s): 5f345d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
3
+ import torch
4
+ import random
5
+
6
+ # Cache the model loading to speed up the app
7
+ @st.cache(allow_output_mutation=True)
8
+ def load_models():
9
+ # Load NER pipeline
10
+ ner = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", tokenizer="dbmdz/bert-large-cased-finetuned-conll03-english")
11
+ # Load question generation model
12
+ qg_tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-e2e-qg")
13
+ qg_model = AutoModelForSeq2SeqLM.from_pretrained("valhalla/t5-small-e2e-qg")
14
+ # Load question answering model
15
+ qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
16
+ qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
17
+ return ner, qg_tokenizer, qg_model, qa_tokenizer, qa_model
18
+
19
+ ner, qg_tokenizer, qg_model, qa_tokenizer, qa_model = load_models()
20
+
21
+ def generate_questions(text, num_questions=5):
22
+ # Generate questions using the question generation model
23
+ input_text = f"generate questions: {text}"
24
+ input_ids = qg_tokenizer.encode(input_text, return_tensors="pt")
25
+ outputs = qg_model.generate(input_ids, max_length=256, num_return_sequences=num_questions)
26
+ questions = [qg_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
27
+ return questions
28
+
29
+ def get_answer(question, context):
30
+ # Get the answer using the question answering model
31
+ inputs = qa_tokenizer.encode_plus(question, context, return_tensors="pt")
32
+ with torch.no_grad():
33
+ outputs = qa_model(**inputs)
34
+ answer_start = torch.argmax(outputs.start_logits)
35
+ answer_end = torch.argmax(outputs.end_logits) + 1
36
+ answer = qa_tokenizer.convert_tokens_to_string(qa_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
37
+ return answer
38
+
39
+ def get_entities(text):
40
+ # Extract entities using NER
41
+ entities = ner(text)
42
+ entity_dict = {}
43
+ for entity in entities:
44
+ word = entity['word']
45
+ entity_type = entity['entity']
46
+ if entity_type not in entity_dict:
47
+ entity_dict[entity_type] = set()
48
+ entity_dict[entity_type].add(word)
49
+ return entity_dict
50
+
51
+ def generate_distractors(answer, entity_dict, answer_type, num_distractors=3):
52
+ if answer_type in entity_dict and len(entity_dict[answer_type]) > 1:
53
+ distractors = list(entity_dict[answer_type] - {answer})
54
+ if len(distractors) >= num_distractors:
55
+ return random.sample(distractors, num_distractors)
56
+ # Fallback: select random words from the text
57
+ words = text.split()
58
+ distractors = random.sample(words, min(len(words), num_distractors))
59
+ return [d for d in distractors if d != answer][:num_distractors]
60
+
61
+ # Streamlit app
62
+ st.title("Text-to-Quiz Generator")
63
+
64
+ text = st.text_area("Enter the text to generate a quiz from:", height=200)
65
+ if st.button("Generate Quiz"):
66
+ if text:
67
+ # Extract entities
68
+ entity_dict = get_entities(text)
69
+ # Generate questions
70
+ questions = generate_questions(text)
71
+ quiz = []
72
+ for question in questions:
73
+ answer = get_answer(question, text)
74
+ # Determine answer type
75
+ answer_type = None
76
+ for ent_type, ents in entity_dict.items():
77
+ if answer in ents:
78
+ answer_type = ent_type
79
+ break
80
+ if answer_type:
81
+ distractors = generate_distractors(answer, entity_dict, answer_type)
82
+ options = [answer] + distractors
83
+ random.shuffle(options)
84
+ quiz.append({"question": question, "options": options, "answer": answer})
85
+
86
+ if quiz:
87
+ st.subheader("Generated Quiz")
88
+ for i, q in enumerate(quiz, 1):
89
+ st.write(f"**Question {i}:** {q['question']}")
90
+ user_answer = st.radio("Choose an answer:", q['options'], key=f"q{i}")
91
+ if st.button("Check Answer", key=f"check{i}"):
92
+ if user_answer == q['answer']:
93
+ st.success("Correct!")
94
+ else:
95
+ st.error(f"Incorrect. The correct answer is: {q['answer']}")
96
+ else:
97
+ st.warning("Could not generate enough questions. Try with a different text.")
98
+ else:
99
+ st.warning("Please enter some text to generate the quiz.")