Murtaza249 committed
Commit c06e820 · verified · 1 Parent(s): 2285290

Update app.py

Files changed (1): app.py (+109 −30)
app.py CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 import torch
-from transformers import pipeline, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 import random
 import time

@@ -11,7 +11,7 @@ st.set_page_config(
     layout="wide"
 )

-# Load the pipeline with caching
+# Load the model with caching
 @st.cache_resource
 def load_model():
     try:
@@ -22,24 +22,20 @@ def load_model():
         # Using a smaller, more efficient model that works well for question generation
         model_name = "valhalla/t5-small-e2e-qg"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

         # Set device
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Using device: {device}")

-        # Load pipeline
-        qg_pipeline = pipeline(
-            "text2text-generation",
-            model=model_name,
-            tokenizer=tokenizer,
-            device=device
-        )
+        # Move model to device
+        model = model.to(device)

-        return qg_pipeline
+        return model, tokenizer, device
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
         print(f"Error details: {str(e)}")
-        return None
+        return None, None, None

 # Custom CSS
 def load_css():
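Note (editorial, not part of the diff): load_model() now returns a (model, tokenizer, device) tuple instead of a pipeline object, and @st.cache_resource memoizes that tuple, so the checkpoint is downloaded and moved to the GPU at most once per Streamlit process. A minimal sketch of how a caller can guard against the (None, None, None) failure value; the names mirror the commit and st.stop() is standard Streamlit:

    # Sketch only: consume the cached loader and bail out cleanly on failure.
    model, tokenizer, device = load_model()
    if model is None or tokenizer is None:
        st.stop()  # load_model() has already surfaced the error via st.error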
@@ -87,29 +83,64 @@ def load_css():
     """, unsafe_allow_html=True)

 # Function to generate questions from a passage
-def generate_questions(pipeline, text, num_questions=5):
+def generate_questions(model, tokenizer, device, text, num_questions=5):
     try:
-        # Make sure text is not too long
-        max_length = 1024
+        # Process text in chunks if it's too long
+        max_length = 512
+        chunks = []
+
         if len(text) > max_length:
-            text = text[:max_length]
+            # Simple chunking based on sentences
+            sentences = text.split('. ')
+            current_chunk = ""
+
+            for sentence in sentences:
+                if len(current_chunk) + len(sentence) < max_length:
+                    current_chunk += sentence + ". "
+                else:
+                    chunks.append(current_chunk)
+                    current_chunk = sentence + ". "
+
+            if current_chunk:
+                chunks.append(current_chunk)
+        else:
+            chunks = [text]

-        # Generate questions
-        result = pipeline(
-            text,
-            max_length=128,
-            num_return_sequences=num_questions,
-            clean_up_tokenization_spaces=True
-        )
+        all_generated_texts = []
+
+        # Process each chunk
+        for chunk in chunks:
+            inputs = tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            # Generate with beam search for multiple diverse outputs
+            with torch.no_grad():
+                outputs = model.generate(
+                    inputs["input_ids"],
+                    max_length=64,
+                    num_beams=5,
+                    num_return_sequences=min(3, num_questions),  # Generate up to 3 questions per chunk
+                    temperature=1.0,
+                    diversity_penalty=1.0,
+                    num_beam_groups=5,
+                    early_stopping=True
+                )
+
+            decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            all_generated_texts.extend(decoded_outputs)
+
+            # If we have enough questions, stop
+            if len(all_generated_texts) >= num_questions:
+                break
+
+        # Ensure we don't return more than num_questions
+        all_generated_texts = all_generated_texts[:num_questions]

         # Process and extract questions and answers
         questions_answers = []
-        for item in result:
-            generated_text = item["generated_text"]
-
-            # Handle different potential formats
+        for generated_text in all_generated_texts:
+            # Try to find question and answer
             if "?" in generated_text:
-                # Try to find question and answer
                 parts = generated_text.split("?", 1)
                 if len(parts) > 1:
                     question = parts[0].strip() + "?"
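Note (editorial, not part of the diff): the rewritten generate_questions() switches from the text2text-generation pipeline to a direct model.generate() call using diverse (group) beam search. Constraints that apply in transformers: num_beams must be divisible by num_beam_groups, num_return_sequences cannot exceed num_beams, and diversity_penalty only takes effect when num_beam_groups > 1; since group beam search does not sample, the temperature=1.0 argument has no effect here. A self-contained sketch of the same generation step, using the commit's checkpoint but an illustrative input sentence:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_name = "valhalla/t5-small-e2e-qg"  # same checkpoint as in the commit
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    text = "The Nile is the longest river in Africa and flows into the Mediterranean Sea."  # illustrative
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_length=64,
            num_beams=6,              # must be divisible by num_beam_groups
            num_beam_groups=3,
            diversity_penalty=1.0,    # only meaningful with num_beam_groups > 1
            num_return_sequences=3,   # must not exceed num_beams
            early_stopping=True,
        )

    for candidate in tokenizer.batch_decode(outputs, skip_special_tokens=True):
        print(candidate)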
@@ -129,6 +160,7 @@ def generate_questions(pipeline, text, num_questions=5):
         return questions_answers
     except Exception as e:
         st.error(f"Error generating questions: {str(e)}")
+        print(f"Detailed error: {str(e)}")
         return []

 # Function to create quiz from generated Q&A pairs
@@ -181,6 +213,48 @@ def create_quiz(questions_answers, num_options=4):

     return quiz_items

+# Alternative question generation using simpler approach
+def generate_questions_simple(text, num_questions=5):
+    try:
+        # Simple question generation for demonstration
+        # In a real app, you'd use a proper NLP model
+
+        # Extract sentences
+        sentences = text.split('.')
+        sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
+
+        # Select random sentences to turn into questions
+        if len(sentences) < num_questions:
+            selected_sentences = sentences
+        else:
+            selected_sentences = random.sample(sentences, num_questions)
+
+        questions_answers = []
+
+        # Simple transformation of sentences into questions
+        for sentence in selected_sentences:
+            # Very simple question generation (not ideal but works as fallback)
+            words = sentence.split()
+            if len(words) < 5:
+                continue
+
+            # Extract key entities for answer
+            potential_answer = " ".join(words[-3:])
+
+            # Create question from beginning of sentence
+            question_words = words[:len(words)-3]
+            question = " ".join(question_words) + "?"
+
+            questions_answers.append({
+                "question": question,
+                "answer": potential_answer
+            })
+
+        return questions_answers
+    except Exception as e:
+        print(f"Error in simple question generation: {str(e)}")
+        return []
+
 # Main app
 def main():
     load_css()
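Note (editorial, not part of the diff): generate_questions_simple() is a pure string heuristic, not a model. It keeps sentences longer than 20 characters, treats the last three words of a sentence as the "answer", and turns the remaining words into the "question", so the output is closer to a fill-in-the-blank prompt than a grammatical question. A quick illustration with a made-up passage:

    passage = (
        "The mitochondrion is often called the powerhouse of the cell. "
        "Photosynthesis converts sunlight into chemical energy inside plant leaves."
    )
    for qa in generate_questions_simple(passage, num_questions=2):
        print(qa["question"], "->", qa["answer"])
    # e.g. "The mitochondrion is often called the powerhouse?" -> "of the cell"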
@@ -215,14 +289,19 @@ def main():
     if passage and len(passage) > 50:
         # Loading the model (with the cached resource)
         with st.spinner("Loading AI model..."):
-            qg_pipeline = load_model()
+            model, tokenizer, device = load_model()

-        if qg_pipeline:
+        if model and tokenizer and device:
             # Generate questions
             with st.spinner("Generating questions..."):
                 # Add a small delay for UX
                 time.sleep(1)
-                questions_answers = generate_questions(qg_pipeline, passage, num_questions)
+                questions_answers = generate_questions(model, tokenizer, device, passage, num_questions)
+
+                # If primary method fails, try fallback approach
+                if not questions_answers:
+                    st.warning("Advanced question generation failed. Using simple approach instead.")
+                    questions_answers = generate_questions_simple(passage, num_questions)

             if questions_answers:
                 # Create quiz
 