Update app.py
app.py
CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 import torch
-from transformers import pipeline, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 import random
 import time
 
@@ -11,7 +11,7 @@ st.set_page_config(
     layout="wide"
 )
 
-# Load the
+# Load the model with caching
 @st.cache_resource
 def load_model():
     try:
@@ -22,24 +22,20 @@ def load_model():
         # Using a smaller, more efficient model that works well for question generation
         model_name = "valhalla/t5-small-e2e-qg"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
         # Set device
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Using device: {device}")
 
-        #
-
-            "text2text-generation",
-            model=model_name,
-            tokenizer=tokenizer,
-            device=device
-        )
+        # Move model to device
+        model = model.to(device)
 
-        return
+        return model, tokenizer, device
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
         print(f"Error details: {str(e)}")
-        return None
+        return None, None, None
 
 # Custom CSS
 def load_css():
@@ -87,29 +83,64 @@ def load_css():
     """, unsafe_allow_html=True)
 
 # Function to generate questions from a passage
-def generate_questions(pipeline, text, num_questions=5):
+def generate_questions(model, tokenizer, device, text, num_questions=5):
     try:
-        #
-        max_length =
+        # Process text in chunks if it's too long
+        max_length = 512
+        chunks = []
+
         if len(text) > max_length:
-
+            # Simple chunking based on sentences
+            sentences = text.split('. ')
+            current_chunk = ""
+
+            for sentence in sentences:
+                if len(current_chunk) + len(sentence) < max_length:
+                    current_chunk += sentence + ". "
+                else:
+                    chunks.append(current_chunk)
+                    current_chunk = sentence + ". "
+
+            if current_chunk:
+                chunks.append(current_chunk)
+        else:
+            chunks = [text]
 
-
-
-
-
-
-
-
+        all_generated_texts = []
+
+        # Process each chunk
+        for chunk in chunks:
+            inputs = tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            # Generate with beam search for multiple diverse outputs
+            with torch.no_grad():
+                outputs = model.generate(
+                    inputs["input_ids"],
+                    max_length=64,
+                    num_beams=5,
+                    num_return_sequences=min(3, num_questions),  # Generate up to 3 questions per chunk
+                    temperature=1.0,
+                    diversity_penalty=1.0,
+                    num_beam_groups=5,
+                    early_stopping=True
+                )
+
+            decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            all_generated_texts.extend(decoded_outputs)
+
+            # If we have enough questions, stop
+            if len(all_generated_texts) >= num_questions:
+                break
+
+        # Ensure we don't return more than num_questions
+        all_generated_texts = all_generated_texts[:num_questions]
 
         # Process and extract questions and answers
         questions_answers = []
-        for
-
-
-            # Handle different potential formats
+        for generated_text in all_generated_texts:
+            # Try to find question and answer
             if "?" in generated_text:
-                # Try to find question and answer
                 parts = generated_text.split("?", 1)
                 if len(parts) > 1:
                     question = parts[0].strip() + "?"
@@ -129,6 +160,7 @@ def generate_questions(pipeline, text, num_questions=5):
         return questions_answers
     except Exception as e:
         st.error(f"Error generating questions: {str(e)}")
+        print(f"Detailed error: {str(e)}")
         return []
 
 # Function to create quiz from generated Q&A pairs
@@ -181,6 +213,48 @@ def create_quiz(questions_answers, num_options=4):
 
     return quiz_items
 
+# Alternative question generation using simpler approach
+def generate_questions_simple(text, num_questions=5):
+    try:
+        # Simple question generation for demonstration
+        # In a real app, you'd use a proper NLP model
+
+        # Extract sentences
+        sentences = text.split('.')
+        sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
+
+        # Select random sentences to turn into questions
+        if len(sentences) < num_questions:
+            selected_sentences = sentences
+        else:
+            selected_sentences = random.sample(sentences, num_questions)
+
+        questions_answers = []
+
+        # Simple transformation of sentences into questions
+        for sentence in selected_sentences:
+            # Very simple question generation (not ideal but works as fallback)
+            words = sentence.split()
+            if len(words) < 5:
+                continue
+
+            # Extract key entities for answer
+            potential_answer = " ".join(words[-3:])
+
+            # Create question from beginning of sentence
+            question_words = words[:len(words)-3]
+            question = " ".join(question_words) + "?"
+
+            questions_answers.append({
+                "question": question,
+                "answer": potential_answer
+            })
+
+        return questions_answers
+    except Exception as e:
+        print(f"Error in simple question generation: {str(e)}")
+        return []
+
 # Main app
 def main():
     load_css()
@@ -215,14 +289,19 @@ def main():
     if passage and len(passage) > 50:
         # Loading the model (with the cached resource)
         with st.spinner("Loading AI model..."):
-
+            model, tokenizer, device = load_model()
 
-        if
+        if model and tokenizer and device:
            # Generate questions
            with st.spinner("Generating questions..."):
                # Add a small delay for UX
                time.sleep(1)
-                questions_answers = generate_questions(
+                questions_answers = generate_questions(model, tokenizer, device, passage, num_questions)
+
+                # If primary method fails, try fallback approach
+                if not questions_answers:
+                    st.warning("Advanced question generation failed. Using simple approach instead.")
+                    questions_answers = generate_questions_simple(passage, num_questions)
 
            if questions_answers:
                # Create quiz
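A quick way to sanity-check the generation path this commit switches to (loading AutoModelForSeq2SeqLM and calling model.generate() directly, instead of the old text2text-generation pipeline) is a small standalone script. The sketch below is illustrative only and is not part of the commit: the sample passage is a placeholder, and it uses plain beam search rather than the diverse beam search settings (num_beam_groups / diversity_penalty) used in the committed code.

# Standalone sketch of the new generation path; not part of the commit.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "valhalla/t5-small-e2e-qg"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Placeholder passage; any paragraph of similar length works.
passage = (
    "The water cycle describes how water evaporates from the surface of the earth, "
    "rises into the atmosphere, cools and condenses into clouds, and falls again "
    "to the surface as precipitation."
)

inputs = tokenizer(passage, return_tensors="pt", max_length=512, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Plain beam search returning a few candidate outputs.
with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        max_length=64,
        num_beams=5,
        num_return_sequences=3,
        early_stopping=True,
    )

for text in tokenizer.batch_decode(outputs, skip_special_tokens=True):
    print(text)

Two details worth checking against the committed generate() call: diverse beam search expects num_beams to be divisible by num_beam_groups (5 and 5 here, so each group holds a single beam), and temperature only applies when sampling is enabled, so with beam search it should have no effect.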