Refactor retrieve function and enhance UI with Gradio components for improved query handling
Browse files
app.py
CHANGED
|
@@ -367,26 +367,34 @@ return_type = List[Hit]
|
|
| 367 |
|
| 368 |
|
| 369 |
## YOUR_CODE_STARTS_HERE
|
| 370 |
-
def retrieve(query: str, topk: int
|
| 371 |
-
ranking = bm25_retriever.retrieve(query=query, topk=
|
| 372 |
hits = []
|
| 373 |
for cid, score in ranking.items():
|
| 374 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
| 375 |
hits.append({"cid": cid, "score": score, "text": text})
|
| 376 |
return hits
|
| 377 |
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
demo.launch()
|
|
|
|
|
|
| 367 |
|
| 368 |
|
| 369 |
## YOUR_CODE_STARTS_HERE
|
| 370 |
+
def retrieve(query: str, topk: int=10) -> return_type:
|
| 371 |
+
ranking = bm25_retriever.retrieve(query=query, topk=topk)
|
| 372 |
hits = []
|
| 373 |
for cid, score in ranking.items():
|
| 374 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
| 375 |
hits.append({"cid": cid, "score": score, "text": text})
|
| 376 |
return hits
|
| 377 |
|
| 378 |
+
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
|
| 379 |
+
gr.Markdown("# BM25 Retriever")
|
| 380 |
+
gr.Markdown("Retrieve documents based on the query using BM25 Retriever")
|
| 381 |
+
query = gr.Textbox(lines=3, placeholder="Enter your query here...")
|
| 382 |
+
topk = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K")
|
| 383 |
+
# Search Button
|
| 384 |
+
examples = gr.Examples(
|
| 385 |
+
examples=[
|
| 386 |
+
["What are the differences between immunodeficiency and autoimmune diseases?"],
|
| 387 |
+
["What are the causes of immunodeficiency?"],
|
| 388 |
+
["What are the symptoms of immunodeficiency?"],
|
| 389 |
+
],
|
| 390 |
+
inputs=[query],
|
| 391 |
+
)
|
| 392 |
+
search_button = gr.Button("Search", elem_id="search_button")
|
| 393 |
+
results_section = gr.JSON(elem_id="results_section")
|
| 394 |
+
search_button.click(
|
| 395 |
+
retrieve,
|
| 396 |
+
inputs=[query, topk],
|
| 397 |
+
outputs=results_section,
|
| 398 |
+
)
|
| 399 |
demo.launch()
|
| 400 |
+
## YOUR_CODE_ENDS_HERE
|