Update app.py
Browse files
app.py
CHANGED
|
@@ -24,26 +24,22 @@ if st.button('Run semantic question answering'):
|
|
| 24 |
except Exception as e:
|
| 25 |
qa_result = str(e)
|
| 26 |
|
| 27 |
-
# top_5_hits = kws_result['hits']['hits'][:5] # print("First 5 results:")
|
| 28 |
top_10_hits = kws_result['hits']['hits'][:10] # print("First 10 results:")
|
| 29 |
top_5_text = [{'text': hit['_source']['content'][:500],
|
| 30 |
'confidence': hit['_score']} for hit in top_10_hits[:5] ]
|
| 31 |
-
|
| 32 |
-
#
|
| 33 |
-
# top_5_para = [hit['_source']['content'][:5000] for hit in top_5_hits]
|
| 34 |
|
| 35 |
DPR_MODEL = "deepset/roberta-base-squad2" #, model="distilbert-base-cased-distilled-squad"
|
| 36 |
pipe_exqa = pipeline("question-answering", model=DPR_MODEL)
|
| 37 |
-
qa_results = [pipe_exqa(question=question, context=paragraph) for paragraph in
|
| 38 |
-
# qa_results = [pipe_exqa(question=question, context=paragraph) for paragraph in top_3_para]
|
| 39 |
qa_results = sorted(qa_results, key=lambda x: x['score'], reverse=True)
|
| 40 |
|
| 41 |
for i, qa_result in enumerate(qa_results):
|
| 42 |
if "answer" in qa_result.keys(): # and qa_result["answer"] is not ""
|
| 43 |
answer_span, answer_score = qa_result["answer"], qa_result["score"]
|
| 44 |
st.write(f'Answer: **{answer_span}**')
|
| 45 |
-
|
| 46 |
-
paragraph = top_5_para[i]
|
| 47 |
start_par, stop_para = max(0, qa_result["start"]-86), min(qa_result["end"]+90, len(paragraph))
|
| 48 |
answer_context = paragraph[start_par:stop_para].replace(answer_span, f'**{answer_span}**')
|
| 49 |
qa_result.update({'context': answer_context, 'paragraph': paragraph})
|
|
|
|
| 24 |
except Exception as e:
|
| 25 |
qa_result = str(e)
|
| 26 |
|
|
|
|
| 27 |
top_10_hits = kws_result['hits']['hits'][:10] # print("First 10 results:")
|
| 28 |
top_5_text = [{'text': hit['_source']['content'][:500],
|
| 29 |
'confidence': hit['_score']} for hit in top_10_hits[:5] ]
|
| 30 |
+
top_3_para = [hit['_source']['content'][:5000] for hit in top_10_hits[:3]]
|
| 31 |
+
# TODO: split + re-rank
|
|
|
|
| 32 |
|
| 33 |
DPR_MODEL = "deepset/roberta-base-squad2" #, model="distilbert-base-cased-distilled-squad"
|
| 34 |
pipe_exqa = pipeline("question-answering", model=DPR_MODEL)
|
| 35 |
+
qa_results = [pipe_exqa(question=question, context=paragraph) for paragraph in top_3_para]
|
|
|
|
| 36 |
qa_results = sorted(qa_results, key=lambda x: x['score'], reverse=True)
|
| 37 |
|
| 38 |
for i, qa_result in enumerate(qa_results):
|
| 39 |
if "answer" in qa_result.keys(): # and qa_result["answer"] is not ""
|
| 40 |
answer_span, answer_score = qa_result["answer"], qa_result["score"]
|
| 41 |
st.write(f'Answer: **{answer_span}**')
|
| 42 |
+
paragraph = top_3_para[i]
|
|
|
|
| 43 |
start_par, stop_para = max(0, qa_result["start"]-86), min(qa_result["end"]+90, len(paragraph))
|
| 44 |
answer_context = paragraph[start_par:stop_para].replace(answer_span, f'**{answer_span}**')
|
| 45 |
qa_result.update({'context': answer_context, 'paragraph': paragraph})
|