import os
import threading
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

import streamlit as st

import nltk
nltk.download('punkt', quiet=True)  # NLTK >= 3.8.2 may also need 'punkt_tab'
from nltk.tokenize import sent_tokenize

from transformers import pipeline

import openai
import google.generativeai as genai

""" |
|
In your Hugging Face Space: |
|
1. Add environment secrets: |
|
- OPENAI_API_KEY (if using OpenAI) |
|
- GEMINI_API_KEY (if using Google PaLM/Gemini) |
|
- MY_PUBMED_EMAIL (to identify yourself to NCBI) |
|
2. In requirements.txt, install: |
|
- streamlit |
|
- requests |
|
- nltk |
|
- transformers |
|
- torch |
|
- openai (if using OpenAI) |
|
- google-generativeai (if using Gemini) |
|
- pandas |
|
""" |
|
|
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")

if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
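
# In a Hugging Face Space, set these as Space secrets in the Settings tab.
# Streamlit can alternatively read credentials via st.secrets (backed by
# .streamlit/secrets.toml) if you prefer that over environment variables.
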
@st.cache_resource
def load_summarizer():
    """
    Load a summarization model (e.g., BART, PEGASUS, T5).
    For a more concise summarization, consider 'google/pegasus-xsum'.
    For a balanced approach, 'facebook/bart-large-cnn' is popular.
    """
    return pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        tokenizer="facebook/bart-large-cnn"
    )

summarizer = load_summarizer()
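
# Note: the pipeline runs on CPU by default; if the Space has a GPU you can
# pass device=0 to pipeline(...) above to speed up summarization.
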
def search_pubmed(query, max_results=3):
    """
    Searches PubMed for PMIDs matching the query.
    Includes the recommended 'tool' and 'email' identifiers in the request.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    return data.get("esearchresult", {}).get("idlist", [])
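
# Quick usage sketch (hypothetical query; ESearch returns PMIDs as strings):
#
#     pmids = search_pubmed("type 2 diabetes complications", max_results=3)
#
# ESearch also accepts PubMed field tags such as 'diabetes[Title/Abstract]'
# if you want tighter matching than free-text search.
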
def fetch_one_abstract(pmid):
    """
    Fetches a single abstract for a given PMID using EFetch.
    Returns (pmid, text).
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "retmode": "text",
        "rettype": "abstract",
        "id": pmid,
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=30)
    resp.raise_for_status()
    raw_text = resp.text.strip()

    if not raw_text:
        return (pmid, "No abstract text found.")

    return (pmid, raw_text)

def fetch_pubmed_abstracts(pmids):
    """
    Fetches multiple PMIDs in parallel to reduce overall latency.
    Returns {pmid: abstract_text}.
    """
    abstracts_map = {}
    if not pmids:
        return abstracts_map  # ThreadPoolExecutor requires max_workers >= 1
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
        future_to_pmid = {
            executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids
        }
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                pmid_result, text = future.result()
                abstracts_map[pmid_result] = text
            except Exception as e:
                abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
    return abstracts_map
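
# NCBI asks clients without an API key to stay at or below roughly 3 requests
# per second. The parallel fetch above is fine for a handful of PMIDs, but if
# you scale out, a shared throttle helps. A minimal sketch (hypothetical
# helper, not wired into the functions above):
import time

_throttle_lock = threading.Lock()
_last_request_time = [0.0]

def throttled_get(url, params, min_interval=0.34):
    """Like requests.get, but spaces successive calls min_interval seconds apart."""
    with _throttle_lock:
        # Sleep off whatever remains of the minimum gap since the last call.
        wait = min_interval - (time.time() - _last_request_time[0])
        if wait > 0:
            time.sleep(wait)
        _last_request_time[0] = time.time()
    return requests.get(url, params=params, timeout=30)
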
def chunk_and_summarize(abstract_text, chunk_size=512):
    """
    Splits a large abstract into manageable chunks (by sentences),
    then summarizes each chunk with the Hugging Face pipeline.
    Returns a combined summary for the entire abstract.
    Note: chunk_size counts whitespace-separated words, a rough proxy
    for model tokens.
    """
    sentences = sent_tokenize(abstract_text)
    chunks = []

    current_chunk = []
    current_length = 0
    for sent in sentences:
        words_in_sent = len(sent.split())
        # Close the current chunk before it exceeds chunk_size, but never
        # emit an empty chunk (a single overlong sentence stays intact).
        if current_chunk and current_length + words_in_sent > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += words_in_sent

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    summarized_pieces = []
    for c in chunks:
        summary_out = summarizer(
            c,
            max_length=100,
            min_length=30,
            do_sample=False
        )
        summarized_pieces.append(summary_out[0]['summary_text'])

    return " ".join(summarized_pieces).strip()
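
# Rough arithmetic behind the default: with chunk_size=512 words, a typical
# ~300-word PubMed abstract is summarized in a single pass, while an unusually
# long ~1,200-word abstract splits into three chunks, each condensed to a
# 30-100 token summary before being joined.
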
def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
    """
    Basic chat completion with a system + user role for OpenAI.
    Uses the legacy openai.ChatCompletion interface, which requires openai<1.0.
    """
    if not OPENAI_API_KEY:
        return "Error: OpenAI API key not provided."
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=temperature
        )
        return response.choices[0].message["content"].strip()
    except Exception as e:
        return f"Error calling OpenAI: {str(e)}"
def gemini_chat(system_prompt, user_message, model_name="gemini-pro", temperature=0.3):
    """
    Basic Gemini chat call using google.generativeai.
    This interface has no separate system role, so the system prompt is
    prepended to the user message.
    """
    if not GEMINI_API_KEY:
        return "Error: Gemini API key not provided."
    try:
        model = genai.GenerativeModel(model_name=model_name)
        response = model.generate_content(
            f"{system_prompt}\n\n{user_message}",
            generation_config=genai.types.GenerationConfig(temperature=temperature)
        )
        return response.text
    except Exception as e:
        return f"Error calling Gemini: {str(e)}"
def build_system_prompt_with_refs(pmids, summarized_map):
    """
    Creates a system prompt that includes the summarized abstracts alongside
    labeled references, so the LLM can quote or cite specific references.
    """
    system_context = (
        "You have access to the following summarized PubMed articles. "
        "When relevant, cite them in your final answer using their reference label.\n\n"
    )
    for idx, pmid in enumerate(pmids, start=1):
        ref_label = f"[Ref{idx}]"
        system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
    system_context += "Use this contextual info to provide a concise, evidence-based answer."
    return system_context
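
# For two retrieved articles the resulting prompt looks like this
# (PMIDs and summaries abridged, purely illustrative):
#
#   You have access to the following summarized PubMed articles. When
#   relevant, cite them in your final answer using their reference label.
#
#   [Ref1] (PMID 12345678): <summary of first abstract>
#
#   [Ref2] (PMID 23456789): <summary of second abstract>
#
#   Use this contextual info to provide a concise, evidence-based answer.
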
def main():
    st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
    st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")

    st.markdown("""
**Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
using PubMed E-utilities, Hugging Face summarization, and optional LLM calls (OpenAI or Gemini).

This version includes:
- **Parallel** fetching for multiple PMIDs
- Advanced **chunking & summarization** of large abstracts
- **Reference labeling** in the final answer
- Clear disclaimers & best-practice structures

---
**Disclaimer**: This is a demonstration prototype for educational or research purposes.
It is *not* a substitute for professional medical advice. Always consult a qualified
healthcare provider for personal health decisions.
""")

    user_query = st.text_area(
        "Enter your medical question or topic:",
        placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
        height=120
    )

    col1, col2 = st.columns(2)
    with col1:
        max_papers = st.slider(
            "Number of PubMed Articles to Retrieve",
            min_value=1,
            max_value=10,
            value=3,
            help="Number of articles to fetch & summarize."
        )
    with col2:
        selected_llm = st.selectbox(
            "Select LLM for Final Generation",
            ["OpenAI: GPT-3.5", "Gemini: Gemini Pro"],
            help="Choose which large language model produces the final answer."
        )

    chunk_size = st.slider(
        "Summarization Chunk Size (words)",
        min_value=256,
        max_value=1024,
        value=512,
        help=(
            "Larger chunks mean fewer summarization passes but risk hitting "
            "the model's token limit; smaller chunks are safer but slower."
        )
    )

    if st.button("Run Enhanced RAG Pipeline"):
        if not user_query.strip():
            st.warning("Please enter a query before running RAG.")
            return

        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(query=user_query, max_results=max_papers)

        if not pmids:
            st.error("No matching PubMed results. Try a different query.")
            return

        with st.spinner("Fetching and summarizing abstracts..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, abstract_text in abstracts_map.items():
                if abstract_text.startswith("Error fetching"):
                    summarized_map[pmid] = abstract_text
                else:
                    summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)

        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            ref_label = f"[Ref{idx}]"
            st.markdown(f"**{ref_label} PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        st.subheader("Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)

        with st.spinner("Generating final answer..."):
            if selected_llm == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
            else:
                answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)

        st.write(answer)
        st.success("RAG Pipeline Complete.")

st.markdown("---") |
|
st.markdown(""" |
|
### Production-Ready Enhancements: |
|
1. **Vector Databases & Advanced Retrieval** |
|
- For large-scale usage, index PubMed articles in a vector DB (e.g. Pinecone, Weaviate) to quickly retrieve relevant passages. |
|
2. **Citation Parsing** |
|
- Automatically detect which abstract chunks contributed to each sentence. |
|
3. **Multi-Lingual** |
|
- Integrate translation pipelines for non-English queries or abstracts. |
|
4. **Rate Limiting** |
|
- Respect NCBI's ~3 requests/sec guideline if you're scaling out. |
|
5. **Robust Logging & Error Handling** |
|
- Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing. |
|
6. **Privacy & Security** |
|
- This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines. |
|
""") |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|