Spaces:
Runtime error
Runtime error
File size: 5,817 Bytes
3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee 3dbe475 3701fee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv
from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search
# Constants
# Number of documents requested from either search backend per query.
RETRIEVE_RESULTS = 20

# Model identifiers offered in the "LLM Model" dropdown; the literal string
# 'None' is the sentinel that ask_llm checks to skip generation.
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',
]
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

# Keyword arguments forwarded unchanged to client.text_generation in ask_llm.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup
# Load the retriever from the on-disk index; this call is deliberately left
# outside the try block, so a failure to load the index stops startup.
RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")

# Run a single one-result query as a smoke test. Any failure here is only
# reported as a warning; startup continues either way.
try:
    gr.Info("Setting up retriever, please wait...")
    rag_initial_output = RAG.search(
        "What is Generative AI in Healthcare?",
        k=1,
    )
    gr.Info("Retriever working successfully!")
except Exception as exc:
    gr.Warning(f"Retriever not working: {exc!s}")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def read_index_date(readme_path="README.md"):
    """Return the index date recorded in the README, or None if unavailable.

    The README is searched for a marker of the form
    ``Index Last Updated : YYYY-MM-DD``. When found, the date is returned
    reformatted as ``DD Mon YYYY`` (e.g. ``05 Mar 2024``).

    Args:
        readme_path: Path of the file to scan.

    Returns:
        The formatted date string, or None when the file does not exist, the
        marker is absent, or the marker carries an impossible date.
    """
    try:
        # Explicit encoding so the result does not depend on the host locale.
        with open(readme_path, "r", encoding="utf-8") as f:
            mdfile = f.read()
    except FileNotFoundError:
        return None

    match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if match is None:
        # README exists but has no marker: treat like a missing README
        # instead of failing on match.group().
        return None

    try:
        return datetime.strptime(match.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        # The pattern only checks digit counts, so e.g. month 13 gets here.
        return None


formatted_date = read_index_date()
if formatted_date is None:
    index_info = "Semantic Search"
else:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"

# Options for the "Search Source" dropdown; the first entry is the local index.
database_choices = [index_info, 'Arxiv Search - Latest - (EXPERIMENTAL)']
# Arxiv API setup
arx_client = arxiv.Client()
is_arxiv_available = True

# Probe the live search once at startup. A raised exception is treated the
# same as an empty result, so a failing probe disables the live option
# instead of preventing the app from starting. This matches how
# update_with_rag_md handles the same call per query.
try:
    check_arxiv_result = get_arxiv_live_search(
        "What is Self Rewarding AI and how can it be used in Multi-Agent Systems?",
        arx_client,
        RETRIEVE_RESULTS,
    )
except Exception as e:
    print(f"Arxiv search check failed: {str(e)}")
    check_arxiv_result = []

if len(check_arxiv_result) == 0:
    is_arxiv_available = False
    print("Arxiv search not working, switching to default search ...")
    # Only offer the local semantic index in the "Search Source" dropdown.
    database_choices = [index_info]
# Gradio UI setup
# NOTE(review): the source had lost its indentation; the nesting below is
# reconstructed from the order of the context managers. Confirm which
# components are meant to sit inside the Group.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)

    with gr.Group():
        search_query = gr.Textbox(
            label='Search',
            placeholder='What is Generative AI in Healthcare?',
        )

        # Collapsed by default; holds the model / context-size / source pickers.
        with gr.Accordion("Advanced Settings", open=False), gr.Row(equal_height=True):
            llm_model = gr.Dropdown(
                choices=LLM_MODELS,
                value=DEFAULT_LLM_MODEL,
                label='LLM Model',
            )
            llm_results = gr.Slider(
                minimum=4,
                maximum=10,
                value=5,
                step=1,
                interactive=True,
                label="Top n results as context",
            )
            database_src = gr.Dropdown(
                choices=database_choices,
                value=index_info,
                label='Search Source',
            )
            # Hidden control: streaming stays enabled and is not user-editable.
            stream_results = gr.Checkbox(
                value=True,
                label="Stream output",
                visible=False,
            )

    output_text = gr.Textbox(
        show_label=True,
        container=True,
        label='LLM Answer',
        visible=True,
    )
    # Invisible textbox that carries the prompt from the retrieval step to the
    # LLM step (see the submit/success chain).
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
    """Retrieve papers for a query and build the results markdown and LLM prompt.

    Args:
        search_query: The user's query text.
        llm_results_use: How many of the leading results are also fed into the
            LLM prompt as context.
        database_choice: Selected search source. ``index_info`` selects the
            local index; any other value selects the live arXiv search, with a
            fallback to the local index if that fails or returns nothing.
        llm_model_picked: Model identifier forwarded to ``get_prompt_text``.

    Returns:
        A ``(markdown, prompt)`` tuple: the markdown listing every retrieved
        result, and the prompt built from the leading results.
    """
    source_used = database_choice

    if database_choice != index_info:
        live_ok = True
        try:
            rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
            if len(rag_out) == 0:
                live_ok = False
        except Exception as exc:
            live_ok = False
            gr.Warning(f"Arxiv Search not working: {str(exc)}, switching to semantic search ...")
        if not live_ok:
            # Fall back to the local index and render results accordingly.
            rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
            source_used = index_info
    else:
        rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)

    md_parts = [mark_text]
    context_parts = []
    for idx, rag_answer in enumerate(rag_out):
        if idx < llm_results_use:
            # Leading results contribute both a markdown entry and a numbered
            # context line for the prompt.
            md_text_paper, prompt_text = get_md_text_abstract(
                rag_answer,
                source=source_used,
                return_prompt_formatting=True,
            )
            context_parts.append(f"{idx+1}. {prompt_text}")
        else:
            md_text_paper = get_md_text_abstract(rag_answer, source=source_used)
        md_parts.append(md_text_paper)

    prompt = get_prompt_text(
        search_query,
        "".join(context_parts),
        llm_model_picked=llm_model_picked,
    )
    return "".join(md_parts), prompt
def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
    """Generate the LLM answer for ``prompt``, yielding text for the output box.

    This is a generator in every mode: because the body contains ``yield``,
    a plain ``return value`` would never reach the caller, so each path
    yields its result instead.

    Args:
        prompt: Prompt text to send to the model.
        llm_model_picked: Model identifier; the literal string ``'None'``
            disables generation.
        stream_outputs: When true, yield the growing answer as tokens arrive;
            otherwise yield the complete answer once.

    Yields:
        Successive values for the answer textbox. If inference fails, a
        warning is raised in the UI and an empty string is yielded.
    """
    model_disabled_text = "LLM Model is disabled"

    if llm_model_picked == 'None':
        if stream_outputs:
            # Reveal the message one character at a time.
            for end in range(1, len(model_disabled_text) + 1):
                yield model_disabled_text[:end]
        else:
            yield model_disabled_text
        # Stop here: falling through would try to run inference against a
        # model literally named 'None'.
        return

    # The module-level imports do not bring InferenceClient into scope, so it
    # is imported where it is used.
    from huggingface_hub import InferenceClient

    client = InferenceClient(llm_model_picked)
    output = ""
    try:
        response = client.text_generation(
            prompt,
            stream=stream_outputs,
            details=False,
            return_full_text=False,
            **GENERATE_KWARGS,
        )
        if stream_outputs:
            for token in response:
                output += token
                yield SaveResponseAndRead(output)
        else:
            yield response
    except Exception as e:
        gr.Warning(f"LLM Inference failed: {str(e)}")
        yield ""
# Event wiring: on submit, run retrieval first (fills the results markdown
# and the hidden prompt box); only if that succeeds, run the LLM step on the
# prompt. Registration is done inside the Blocks context because Gradio
# requires event listeners to be attached within one.
with demo:
    search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, input],
    ).success(
        ask_llm,
        [input, llm_model, stream_results],
        output_text,
    )

# Enable the request queue, then start the server.
demo.queue().launch()