Murtaza249 committed on
Commit
d1b1252
Β·
verified Β·
1 Parent(s): 44a679d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -55
app.py CHANGED
@@ -1,72 +1,312 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
  import random
 
 
 
 
 
 
 
 
4
 
5
  # Load the pipeline with caching
6
  @st.cache_resource
7
- def load_pipeline():
8
- return pipeline("text2text-generation", model="valhalla/t5-small-e2e-qg")
9
-
10
- # Initialize the pipeline
11
- qg_pipeline = load_pipeline()
12
-
13
- # UI Header
14
- st.markdown("# Text-to-Quiz Generator")
15
- st.markdown("Enter a passage below and let AI create an interactive quiz for you!")
16
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Input section
19
- passage = st.text_area("Paste your text here:", height=150)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- if st.button("Generate Quiz"):
22
- with st.spinner("Generating your quiz..."):
23
- # Generate question-answer pairs (adjust based on actual model output format)
24
- generated = qg_pipeline(passage, max_length=512)
 
 
 
25
 
26
- # For this example, assume the model outputs a list of strings like "Q: ... A: ..."
27
- # Parse the output into questions and answers (modify this based on your model's actual output)
28
- quiz_data = []
29
- for item in generated:
30
- if "Q:" in item["generated_text"] and "A:" in item["generated_text"]:
31
- q, a = item["generated_text"].split("A:")
32
- question = q.replace("Q:", "").strip()
33
- answer = a.strip()
34
- quiz_data.append({"question": question, "answer": answer})
35
 
36
- if not quiz_data:
37
- st.error("No valid questions generated. Please try a different passage.")
38
- else:
39
- # Store quiz data in session state
40
- st.session_state["quiz_data"] = quiz_data
41
- st.session_state["user_answers"] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Display quiz
44
- if "quiz_data" in st.session_state:
45
- st.markdown("### Your Quiz")
46
- questions = [item["question"] for item in st.session_state["quiz_data"]]
47
- answers = [item["answer"] for item in st.session_state["quiz_data"]]
 
 
 
 
 
 
48
 
49
- # Generate distractors for each question
50
- for i, question in enumerate(questions):
51
- correct_answer = answers[i]
52
- distractors = [a for j, a in enumerate(answers) if j != i]
53
- options = [correct_answer] + random.sample(distractors, min(3, len(distractors)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  random.shuffle(options)
55
 
56
- st.write(f"**Question {i+1}:** {question}")
57
- st.session_state["user_answers"][question] = st.radio(
58
- "Select your answer:", options, key=f"q{i}"
59
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Submit button
62
- if st.button("Submit Answers"):
63
  score = 0
64
- for i, question in enumerate(questions):
65
- user_answer = st.session_state["user_answers"][question]
66
- correct_answer = answers[i]
67
- if user_answer == correct_answer:
68
- st.success(f"**Question {i+1}:** Correct! The answer is '{correct_answer}'.")
 
 
 
 
69
  score += 1
 
70
  else:
71
- st.error(f"**Question {i+1}:** Incorrect. The correct answer is '{correct_answer}', not '{user_answer}'.")
72
- st.markdown(f"### Your Score: {score}/{len(questions)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# UI framework, torch (device detection), and the HF pipeline factory.
import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer
import random
import time

# Configure page
# NOTE: set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Text-to-Quiz Generator",
    page_icon="🧠",
    layout="wide"
)
13
 
14
# Load the pipeline with caching so the model is built once per process.
@st.cache_resource
def load_model():
    """Build and cache the T5 question-generation pipeline.

    Returns:
        The transformers text2text-generation pipeline on success, or
        ``None`` if loading fails (the error is shown in the UI and
        printed to the server log).
    """
    try:
        # Log runtime environment details for debugging on the host.
        print(f"PyTorch version: {torch.__version__}")
        use_gpu = torch.cuda.is_available()
        print(f"CUDA available: {use_gpu}")

        # Small end-to-end question-generation checkpoint.
        checkpoint = "valhalla/t5-small-e2e-qg"
        hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        # Prefer GPU when available, otherwise fall back to CPU.
        target_device = "cuda" if use_gpu else "cpu"
        print(f"Using device: {target_device}")

        return pipeline(
            "text2text-generation",
            model=checkpoint,
            tokenizer=hf_tokenizer,
            device=target_device,
        )
    except Exception as e:
        # Surface the failure both in the UI and the log; callers check for None.
        st.error(f"Error loading model: {str(e)}")
        print(f"Error details: {str(e)}")
        return None
43
 
44
# Custom CSS injected once per page render.
def load_css():
    """Apply the app-wide custom styles (question boxes, buttons, score box)."""
    styles = """
    <style>
    .main {
        padding: 2rem;
    }
    .question-box {
        background-color: #f0f7ff;
        padding: 1.5rem;
        border-radius: 10px;
        margin-bottom: 1rem;
        border-left: 5px solid #4361ee;
    }
    .stButton button {
        background-color: #4361ee;
        color: white;
        padding: 0.5rem 1rem;
        border-radius: 5px;
        border: none;
        font-weight: bold;
    }
    .title-box {
        padding: 1rem;
        border-radius: 5px;
        margin-bottom: 2rem;
        text-align: center;
        background: linear-gradient(90deg, #4361ee 0%, #3a0ca3 100%);
        color: white;
    }
    .score-box {
        font-size: 1.5rem;
        padding: 1rem;
        border-radius: 5px;
        text-align: center;
        font-weight: bold;
    }
    .feedback {
        padding: 1rem;
        border-radius: 5px;
        margin: 1rem 0;
    }
    </style>
    """
    st.markdown(styles, unsafe_allow_html=True)
88
 
89
# Function to generate questions from a passage
def generate_questions(pipeline, text, num_questions=5):
    """Run the text2text pipeline on *text* and parse out Q/A pairs.

    Args:
        pipeline: a callable transformers pipeline. (NOTE: this parameter
            name shadows the imported ``transformers.pipeline`` factory
            inside this function; kept for caller compatibility.)
        text: source passage; truncated to 1024 characters before use.
        num_questions: number of candidate generations to request.

    Returns:
        A list of ``{"question": ..., "answer": ...}`` dicts (possibly empty).
    """
    try:
        # Make sure text is not too long.
        # NOTE(review): this truncates by characters, not tokens — the
        # tokenizer may truncate further; confirm against the model's
        # maximum input length.
        max_length = 1024
        if len(text) > max_length:
            text = text[:max_length]

        # Generate questions.
        # num_return_sequences > 1 requires beam search (or sampling);
        # with default greedy decoding transformers raises a ValueError,
        # so request at least as many beams as sequences.
        result = pipeline(
            text,
            max_length=128,
            num_beams=num_questions,
            num_return_sequences=num_questions,
            clean_up_tokenization_spaces=True
        )

        # Process and extract questions and answers.
        questions_answers = []
        for item in result:
            generated_text = item["generated_text"]

            # Handle different potential formats: split at the first '?' —
            # question before it, candidate answer after it.
            if "?" in generated_text:
                parts = generated_text.split("?", 1)
                if len(parts) > 1:
                    question = parts[0].strip() + "?"
                    answer = parts[1].strip()

                    # Clean up answer if it starts with common prefixes
                    # the model sometimes emits.
                    for prefix in ["answer:", "a:", " - "]:
                        if answer.lower().startswith(prefix):
                            answer = answer[len(prefix):].strip()

                    # Keep only non-trivial pairs.
                    if question and answer and len(question) > 10:
                        questions_answers.append({
                            "question": question,
                            "answer": answer
                        })

        return questions_answers
    except Exception as e:
        st.error(f"Error generating questions: {str(e)}")
        return []
133
 
134
# Function to create quiz from generated Q&A pairs
def create_quiz(questions_answers, num_options=4):
    """Turn Q/A pairs into multiple-choice quiz items with distractors.

    Args:
        questions_answers: list of ``{"question", "answer"}`` dicts.
        num_options: maximum number of choices per question (default 4).

    Returns:
        A list of items ``{"id", "question", "correct_answer", "options"}``.
        Questions with no available distractors are dropped, matching the
        original behavior.
    """
    quiz_items = []

    # First filter out very short answers and duplicate questions.
    filtered_qa = []
    seen_questions = set()

    for qa in questions_answers:
        q = qa["question"].strip()
        a = qa["answer"].strip()

        # Skip very short answers / questions.
        if len(a) < 2 or len(q) < 10:
            continue

        # Skip duplicate questions (case-insensitive).
        q_lower = q.lower()
        if q_lower in seen_questions:
            continue

        seen_questions.add(q_lower)
        filtered_qa.append({"question": q, "answer": a})

    # Distractors are drawn from the other questions' answers.
    all_answers = [qa["answer"] for qa in filtered_qa]

    for i, qa in enumerate(filtered_qa):
        correct_answer = qa["answer"]

        # Deduplicate candidate distractors (order-preserving) so the same
        # wrong answer cannot appear twice among a question's options —
        # random.sample would otherwise happily pick duplicate strings.
        other_answers = list(dict.fromkeys(
            a for a in all_answers if a != correct_answer
        ))
        if other_answers:
            # Select random distractors.
            num_distractors = min(num_options - 1, len(other_answers))
            distractors = random.sample(other_answers, num_distractors)

            # Combine correct answer and distractors, then shuffle.
            options = [correct_answer] + distractors
            random.shuffle(options)

            quiz_items.append({
                "id": i,
                "question": qa["question"],
                "correct_answer": correct_answer,
                "options": options
            })

    return quiz_items
183
+
184
# Main app
def main():
    """Render the whole UI: input pane, quiz generation, quiz form, results."""

    def _rerun():
        # st.experimental_rerun() was removed in newer Streamlit releases;
        # prefer st.rerun() when available, fall back for old versions.
        if hasattr(st, "rerun"):
            st.rerun()
        else:
            st.experimental_rerun()

    load_css()

    # App title
    st.markdown('<div class="title-box"><h1>🧠 Text-to-Quiz Generator</h1></div>', unsafe_allow_html=True)

    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("### Enter a passage to generate quiz questions")
        passage = st.text_area(
            "Paste your text here:",
            height=200,
            placeholder="Enter a paragraph or article here to generate quiz questions..."
        )

    with col2:
        st.markdown("### Settings")
        num_questions = st.slider("Number of questions to generate", 3, 10, 5)
        st.markdown("---")
        st.markdown("""
        **Tips for best results:**
        - Use clear, factual content
        - Include specific details
        - Text length: 100-500 words works best
        - Educational content works better than narrative
        """)

    # Generate Quiz button
    if st.button("🧠 Generate Quiz"):
        if passage and len(passage) > 50:
            # Loading the model (with the cached resource)
            with st.spinner("Loading AI model..."):
                qg_pipeline = load_model()

            if qg_pipeline:
                # Generate questions
                with st.spinner("Generating questions..."):
                    # Add a small delay for UX
                    time.sleep(1)
                    questions_answers = generate_questions(qg_pipeline, passage, num_questions)

                if questions_answers:
                    # Create quiz
                    quiz_items = create_quiz(questions_answers)

                    if quiz_items:
                        # Store in session state so the quiz survives reruns
                        st.session_state.quiz_items = quiz_items
                        st.session_state.user_answers = {}
                        st.session_state.quiz_submitted = False
                        st.session_state.show_explanations = False
                        _rerun()
                    else:
                        st.error("Couldn't create valid quiz questions. Please try a different text or add more content.")
                else:
                    st.error("Failed to generate questions. Please try a different passage.")
            else:
                st.error("Failed to load the question generation model. Please try again.")
        else:
            st.warning("Please enter a longer passage (at least 50 characters).")

    # Display quiz if available in session state
    if "quiz_items" in st.session_state and st.session_state.quiz_items:
        st.markdown("---")
        st.markdown("## Your Quiz")

        quiz_items = st.session_state.quiz_items

        # Create a form for the quiz (answers only register on submit)
        with st.form("quiz_form"):
            for i, item in enumerate(quiz_items):
                st.markdown(f'<div class="question-box"><h3>Question {i+1}</h3><p>{item["question"]}</p></div>', unsafe_allow_html=True)

                key = f"question_{item['id']}"
                st.session_state.user_answers[key] = st.radio(
                    "Select your answer:",
                    options=item["options"],
                    key=key
                )

            submit_button = st.form_submit_button("Submit Answers")

            if submit_button:
                st.session_state.quiz_submitted = True

        # Show results if quiz was submitted.
        # Guarded with .get(): attribute access would raise if the key
        # is ever missing from session state.
        if st.session_state.get("quiz_submitted", False):
            score = 0

            st.markdown("## Quiz Results")

            for i, item in enumerate(quiz_items):
                key = f"question_{item['id']}"
                user_answer = st.session_state.user_answers[key]
                correct = user_answer == item["correct_answer"]

                if correct:
                    score += 1
                    st.markdown(f'<div class="feedback" style="background-color: #d4edda; border-left: 5px solid #28a745;"><h4>Question {i+1}: Correct! βœ…</h4><p><strong>Your answer:</strong> {user_answer}</p></div>', unsafe_allow_html=True)
                else:
                    st.markdown(f'<div class="feedback" style="background-color: #f8d7da; border-left: 5px solid #dc3545;"><h4>Question {i+1}: Incorrect ❌</h4><p><strong>Your answer:</strong> {user_answer}<br><strong>Correct answer:</strong> {item["correct_answer"]}</p></div>', unsafe_allow_html=True)

            # Show score with a tiered message by percentage band
            percentage = (score / len(quiz_items)) * 100

            if percentage >= 80:
                color = "#28a745"  # Green
                message = "Excellent! πŸ†"
            elif percentage >= 60:
                color = "#17a2b8"  # Blue
                message = "Good job! πŸ‘"
            else:
                color = "#ffc107"  # Yellow
                message = "Keep practicing! πŸ“š"

            st.markdown(f'<div class="score-box" style="background-color: {color}15; border-left: 5px solid {color};">{message}<br>Your Score: {score}/{len(quiz_items)} ({percentage:.1f}%)</div>', unsafe_allow_html=True)

            # Restart button
            if st.button("Generate Another Quiz"):
                # Clear session state and rerun
                for key in ["quiz_items", "user_answers", "quiz_submitted", "show_explanations"]:
                    if key in st.session_state:
                        del st.session_state[key]
                _rerun()
310
+
311
# Script entry point (Streamlit re-executes this module on every rerun).
if __name__ == "__main__":
    main()