Qiwei97 committed on
Commit dbe4e7d · 1 Parent(s): 0f64541

Create questiongenerator.py

Files changed (1)
  1. questiongenerator.py +354 -0
questiongenerator.py ADDED
import re
import random
import json

import numpy as np
import torch
import en_core_web_sm  # spaCy English model; install via: python -m spacy download en_core_web_sm
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)


class QuestionGenerator:
    """Generates questions from a passage of text with a pretrained T5 model,
    optionally ranking the results with a BERT-based QA evaluator."""

    def __init__(self, model_dir=None):

        QG_PRETRAINED = "iarfmoose/t5-base-question-generator"
        self.ANSWER_TOKEN = "<answer>"
        self.CONTEXT_TOKEN = "<context>"
        self.SEQ_LENGTH = 512

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED, use_fast=False)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
        self.qg_model.to(self.device)

        self.qa_evaluator = QAEvaluator(model_dir)

    def generate(
        self, article, use_evaluator=True, num_questions=None, answer_style="all"
    ):

        print("Generating questions...\n")

        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
        generated_questions = self.generate_questions_from_inputs(qg_inputs)

        message = "{} questions don't match {} answers".format(
            len(generated_questions), len(qg_answers)
        )
        assert len(generated_questions) == len(qg_answers), message

        if use_evaluator:

            print("Evaluating QA pairs...\n")

            encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(
                generated_questions, qg_answers
            )
            scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
            if num_questions:
                qa_list = self._get_ranked_qa_pairs(
                    generated_questions, qg_answers, scores, num_questions
                )
            else:
                qa_list = self._get_ranked_qa_pairs(
                    generated_questions, qg_answers, scores
                )

        else:
            print("Skipping evaluation step.\n")
            qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)

        return qa_list

    def generate_qg_inputs(self, text, answer_style):

        VALID_ANSWER_STYLES = ["all", "sentences", "multiple_choice"]

        if answer_style not in VALID_ANSWER_STYLES:
            raise ValueError(
                "Invalid answer style {}. Please choose from {}".format(
                    answer_style, VALID_ANSWER_STYLES
                )
            )

        inputs = []
        answers = []

        if answer_style == "sentences" or answer_style == "all":
            segments = self._split_into_segments(text)
            for segment in segments:
                sentences = self._split_text(segment)
                prepped_inputs, prepped_answers = self._prepare_qg_inputs(
                    sentences, segment
                )
                inputs.extend(prepped_inputs)
                answers.extend(prepped_answers)

        if answer_style == "multiple_choice" or answer_style == "all":
            sentences = self._split_text(text)
            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
            inputs.extend(prepped_inputs)
            answers.extend(prepped_answers)

        return inputs, answers

    def generate_questions_from_inputs(self, qg_inputs):
        generated_questions = []

        for qg_input in qg_inputs:
            question = self._generate_question(qg_input)
            generated_questions.append(question)

        return generated_questions

    def _split_text(self, text):
        MAX_SENTENCE_LEN = 128

        sentences = re.findall(r".*?[.!?]", text)

        cut_sentences = []
        for sentence in sentences:
            if len(sentence) > MAX_SENTENCE_LEN:
                cut_sentences.extend(re.split("[,;:)]", sentence))
        # temporary solution to remove useless post-quote sentence fragments
        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
        sentences = sentences + cut_sentences

        return list(set([s.strip(" ") for s in sentences]))

    def _split_into_segments(self, text):
        MAX_TOKENS = 490

        paragraphs = text.split("\n")
        tokenized_paragraphs = [
            self.qg_tokenizer(p)["input_ids"] for p in paragraphs if len(p) > 0
        ]

        segments = []
        while len(tokenized_paragraphs) > 0:
            segment = []
            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
                paragraph = tokenized_paragraphs.pop(0)
                segment.extend(paragraph)
            segments.append(segment)
        return [self.qg_tokenizer.decode(s) for s in segments]

    def _prepare_qg_inputs(self, sentences, text):
        inputs = []
        answers = []

        for sentence in sentences:
            qg_input = "{} {} {} {}".format(
                self.ANSWER_TOKEN, sentence, self.CONTEXT_TOKEN, text
            )
            inputs.append(qg_input)
            answers.append(sentence)

        return inputs, answers

    def _prepare_qg_inputs_MC(self, sentences):

        spacy_nlp = en_core_web_sm.load()
        docs = list(spacy_nlp.pipe(sentences, disable=["parser"]))
        inputs_from_text = []
        answers_from_text = []

        for i in range(len(sentences)):
            entities = docs[i].ents
            if entities:
                for entity in entities:
                    qg_input = "{} {} {} {}".format(
                        self.ANSWER_TOKEN, entity, self.CONTEXT_TOKEN, sentences[i]
                    )
                    answers = self._get_MC_answers(entity, docs)
                    inputs_from_text.append(qg_input)
                    answers_from_text.append(answers)

        return inputs_from_text, answers_from_text

    def _get_MC_answers(self, correct_answer, docs):

        entities = []
        for doc in docs:
            entities.extend([{"text": e.text, "label_": e.label_} for e in doc.ents])

        # remove duplicate elements
        entities_json = [json.dumps(kv) for kv in entities]
        pool = set(entities_json)
        num_choices = (
            min(4, len(pool)) - 1
        )  # -1 because we already have the correct answer

        # add the correct answer
        final_choices = []
        correct_label = correct_answer.label_
        final_choices.append({"answer": correct_answer.text, "correct": True})
        pool.remove(
            json.dumps({"text": correct_answer.text, "label_": correct_answer.label_})
        )

        # find answers with the same NER label (substring match on the serialized entity)
        matches = [e for e in pool if correct_label in e]

        # if we don't have enough then add some other random answers
        if len(matches) < num_choices:
            choices = matches
            pool = pool.difference(set(choices))
            # random.sample requires a sequence, not a set
            choices.extend(random.sample(sorted(pool), num_choices - len(choices)))
        else:
            choices = random.sample(matches, num_choices)

        choices = [json.loads(s) for s in choices]
        for choice in choices:
            final_choices.append({"answer": choice["text"], "correct": False})
        random.shuffle(final_choices)
        return final_choices

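    # Illustrative shape of the list returned by _get_MC_answers (entity values
    # are hypothetical; the order is shuffled):
    # [{"answer": "Paris", "correct": True},
    #  {"answer": "Berlin", "correct": False},
    #  {"answer": "Madrid", "correct": False},
    #  {"answer": "Rome", "correct": False}]
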
    def _generate_question(self, qg_input):
        self.qg_model.eval()
        encoded_input = self._encode_qg_input(qg_input)
        with torch.no_grad():
            output = self.qg_model.generate(input_ids=encoded_input["input_ids"])
        question = self.qg_tokenizer.decode(output[0], skip_special_tokens=True)
        return question

    def _encode_qg_input(self, qg_input):
        return self.qg_tokenizer(
            qg_input,
            padding="max_length",
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

    def _get_ranked_qa_pairs(
        self, generated_questions, qg_answers, scores, num_questions=10
    ):
        if num_questions > len(scores):
            num_questions = len(scores)
            print(
                "\nWas only able to generate {} questions. For more questions, please input a longer text.".format(
                    num_questions
                )
            )

        qa_list = []
        for i in range(num_questions):
            index = scores[i]
            qa = self._make_dict(
                generated_questions[index].split("?")[0] + "?", qg_answers[index]
            )
            qa_list.append(qa)
        return qa_list

    def _get_all_qa_pairs(self, generated_questions, qg_answers):
        qa_list = []
        for i in range(len(generated_questions)):
            qa = self._make_dict(
                generated_questions[i].split("?")[0] + "?", qg_answers[i]
            )
            qa_list.append(qa)
        return qa_list

    def _make_dict(self, question, answer):
        qa = {}
        qa["question"] = question
        qa["answer"] = answer
        return qa


class QAEvaluator:
    """Scores question-answer pairs with a pretrained BERT sequence classifier."""

    def __init__(self, model_dir=None):

        QAE_PRETRAINED = "iarfmoose/bert-base-cased-qa-evaluator"
        self.SEQ_LENGTH = 512

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
        self.qae_model = AutoModelForSequenceClassification.from_pretrained(
            QAE_PRETRAINED
        )
        self.qae_model.to(self.device)

    def encode_qa_pairs(self, questions, answers):
        encoded_pairs = []
        for i in range(len(questions)):
            encoded_qa = self._encode_qa(questions[i], answers[i])
            encoded_pairs.append(encoded_qa.to(self.device))
        return encoded_pairs

    def get_scores(self, encoded_qa_pairs):
        scores = {}
        self.qae_model.eval()
        with torch.no_grad():
            for i in range(len(encoded_qa_pairs)):
                scores[i] = self._evaluate_qa(encoded_qa_pairs[i])

        # indices of the QA pairs, highest-scoring first
        return [
            k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)
        ]

    def _encode_qa(self, question, answer):
        if type(answer) is list:
            # multiple-choice answers: encode only the correct option
            for a in answer:
                if a["correct"]:
                    correct_answer = a["answer"]
        else:
            correct_answer = answer
        return self.qae_tokenizer(
            text=question,
            text_pair=correct_answer,
            padding="max_length",
            max_length=self.SEQ_LENGTH,
            truncation=True,
            return_tensors="pt",
        )

    def _evaluate_qa(self, encoded_qa_pair):
        output = self.qae_model(**encoded_qa_pair)
        # logit of the positive ("good pair") class
        return output[0][0][1]


def print_qa(qa_list, show_answers=True):
    for i in range(len(qa_list)):
        space = " " * int(np.where(i < 9, 3, 4))  # wider space for 2-digit question numbers

        print("{}) Q: {}".format(i + 1, qa_list[i]["question"]))

        answer = qa_list[i]["answer"]

        # print a list of multiple-choice answers
        if type(answer) is list:

            if show_answers:
                print(
                    "{}A: 1.".format(space),
                    answer[0]["answer"],
                    np.where(answer[0]["correct"], "(correct)", ""),
                )
                for j in range(1, len(answer)):
                    print(
                        "{}{}.".format(space + "   ", j + 1),
                        answer[j]["answer"],
                        np.where(answer[j]["correct"], "(correct)", ""),
                    )

            else:
                print("{}A: 1.".format(space), answer[0]["answer"])
                for j in range(1, len(answer)):
                    print("{}{}.".format(space + "   ", j + 1), answer[j]["answer"])
            print("")

        # print full sentence answers
        else:
            if show_answers:
                print("{}A:".format(space), answer, "\n")