Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,99 +1,46 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import pipeline
|
3 |
-
import torch
|
4 |
import random
|
5 |
|
6 |
-
|
7 |
-
@st.cache(allow_output_mutation=True)
|
8 |
-
def load_models():
|
9 |
-
# Load NER pipeline
|
10 |
-
ner = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", tokenizer="dbmdz/bert-large-cased-finetuned-conll03-english")
|
11 |
-
# Load question generation model
|
12 |
-
qg_tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-e2e-qg")
|
13 |
-
qg_model = AutoModelForSeq2SeqLM.from_pretrained("valhalla/t5-small-e2e-qg")
|
14 |
-
# Load question answering model
|
15 |
-
qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
|
16 |
-
qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
|
17 |
-
return ner, qg_tokenizer, qg_model, qa_tokenizer, qa_model
|
18 |
|
19 |
-
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
input_text = f"generate questions: {text}"
|
24 |
-
input_ids = qg_tokenizer.encode(input_text, return_tensors="pt")
|
25 |
-
outputs = qg_model.generate(input_ids, max_length=256, num_return_sequences=num_questions)
|
26 |
-
questions = [qg_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
|
27 |
-
return questions
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
outputs = qa_model(**inputs)
|
34 |
-
answer_start = torch.argmax(outputs.start_logits)
|
35 |
-
answer_end = torch.argmax(outputs.end_logits) + 1
|
36 |
-
answer = qa_tokenizer.convert_tokens_to_string(qa_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
|
37 |
-
return answer
|
38 |
|
39 |
-
|
40 |
-
# Extract entities using NER
|
41 |
-
entities = ner(text)
|
42 |
-
entity_dict = {}
|
43 |
-
for entity in entities:
|
44 |
-
word = entity['word']
|
45 |
-
entity_type = entity['entity']
|
46 |
-
if entity_type not in entity_dict:
|
47 |
-
entity_dict[entity_type] = set()
|
48 |
-
entity_dict[entity_type].add(word)
|
49 |
-
return entity_dict
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
distractors = list(entity_dict[answer_type] - {answer})
|
54 |
-
if len(distractors) >= num_distractors:
|
55 |
-
return random.sample(distractors, num_distractors)
|
56 |
-
# Fallback: select random words from the text
|
57 |
-
words = text.split()
|
58 |
-
distractors = random.sample(words, min(len(words), num_distractors))
|
59 |
-
return [d for d in distractors if d != answer][:num_distractors]
|
60 |
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
text = st.text_area("Enter the text to generate a quiz from:", height=200)
|
65 |
-
if st.button("Generate Quiz"):
|
66 |
-
if text:
|
67 |
-
# Extract entities
|
68 |
-
entity_dict = get_entities(text)
|
69 |
-
# Generate questions
|
70 |
-
questions = generate_questions(text)
|
71 |
-
quiz = []
|
72 |
-
for question in questions:
|
73 |
-
answer = get_answer(question, text)
|
74 |
-
# Determine answer type
|
75 |
-
answer_type = None
|
76 |
-
for ent_type, ents in entity_dict.items():
|
77 |
-
if answer in ents:
|
78 |
-
answer_type = ent_type
|
79 |
-
break
|
80 |
-
if answer_type:
|
81 |
-
distractors = generate_distractors(answer, entity_dict, answer_type)
|
82 |
-
options = [answer] + distractors
|
83 |
-
random.shuffle(options)
|
84 |
-
quiz.append({"question": question, "options": options, "answer": answer})
|
85 |
-
|
86 |
-
if quiz:
|
87 |
-
st.subheader("Generated Quiz")
|
88 |
-
for i, q in enumerate(quiz, 1):
|
89 |
-
st.write(f"**Question {i}:** {q['question']}")
|
90 |
-
user_answer = st.radio("Choose an answer:", q['options'], key=f"q{i}")
|
91 |
-
if st.button("Check Answer", key=f"check{i}"):
|
92 |
-
if user_answer == q['answer']:
|
93 |
-
st.success("Correct!")
|
94 |
-
else:
|
95 |
-
st.error(f"Incorrect. The correct answer is: {q['answer']}")
|
96 |
-
else:
|
97 |
-
st.warning("Could not generate enough questions. Try with a different text.")
|
98 |
-
else:
|
99 |
-
st.warning("Please enter some text to generate the quiz.")
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import pipeline
|
|
|
3 |
import random
|
4 |
|
5 |
+
st.set_page_config(page_title="Text-to-Quiz Generator", layout="centered")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
st.title("π Text-to-Quiz Generator")
|
8 |
+
st.write("Enter a passage below and let AI generate a quiz for you!")
|
9 |
|
10 |
+
# Text input
|
11 |
+
text_input = st.text_area("Paste your text here:", height=200)
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# Load model pipeline from Hugging Face
|
14 |
+
@st.cache_resource
|
15 |
+
def load_pipeline():
|
16 |
+
return pipeline("e2e-qg") # End-to-End Question Generation (QG)
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
qg_pipeline = load_pipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
if st.button("Generate Quiz") and text_input:
|
21 |
+
st.subheader("π Generated Quiz")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
try:
|
24 |
+
questions = qg_pipeline(text_input)
|
25 |
+
|
26 |
+
for i, item in enumerate(questions):
|
27 |
+
question = item['question']
|
28 |
+
correct_answer = item['answer']
|
29 |
+
# Generate fake choices
|
30 |
+
choices = [correct_answer] + [f"Option {chr(66 + j)}" for j in range(3)]
|
31 |
+
random.shuffle(choices)
|
32 |
+
|
33 |
+
st.markdown(f"**Q{i+1}. {question}**")
|
34 |
+
user_answer = st.radio("Choose an answer:", choices, key=i)
|
35 |
+
|
36 |
+
if user_answer == correct_answer:
|
37 |
+
st.success("β
Correct!")
|
38 |
+
else:
|
39 |
+
st.error(f"β Incorrect. Correct answer: **{correct_answer}**")
|
40 |
+
st.markdown("---")
|
41 |
+
|
42 |
+
except Exception as e:
|
43 |
+
st.error(f"Something went wrong: {e}")
|
44 |
+
else:
|
45 |
+
st.info("π Paste some text and click 'Generate Quiz'")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|