import os
import re
import json
import math
import requests
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import streamlit as st
import pandas as pd
# NLP
import nltk
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by sent_tokenize on NLTK >= 3.8.2
from nltk.tokenize import sent_tokenize
# Hugging Face Transformers
from transformers import pipeline
# Optional: OpenAI and Google Generative AI
import openai
import google.generativeai as genai
###############################################################################
# CONFIG & ENV #
###############################################################################
"""
In your Hugging Face Space:
1. Add environment secrets:
- OPENAI_API_KEY (if using OpenAI)
- GEMINI_API_KEY (if using Google PaLM/Gemini)
- MY_PUBMED_EMAIL (to identify yourself to NCBI)
2. In requirements.txt, install:
- streamlit
- requests
- nltk
- transformers
- torch
- openai (if using OpenAI)
- google-generativeai (if using Gemini)
- pandas
"""
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
###############################################################################
# SUMMARIZATION PIPELINE #
###############################################################################
@st.cache_resource
def load_summarizer():
    """
    Load a summarization model (e.g., BART, PEGASUS, T5).
    For more concise summaries, consider 'google/pegasus-xsum'.
    For a balanced approach, 'facebook/bart-large-cnn' is popular.
    """
    return pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        tokenizer="facebook/bart-large-cnn"
    )
summarizer = load_summarizer()
###############################################################################
# PUBMED RETRIEVAL (NCBI E-utilities) #
###############################################################################
def search_pubmed(query, max_results=3):
    """
    Searches PubMed for PMIDs matching the query.
    Includes the recommended 'tool' and 'email' parameters in the request.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    id_list = data.get("esearchresult", {}).get("idlist", [])
    return id_list
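# Illustrative usage (the PMIDs below are hypothetical; real results depend on
# the live PubMed index at query time):
#
#   search_pubmed("SGLT2 inhibitors heart failure", max_results=2)
#   -> ['37123456', '36987654']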
def fetch_one_abstract(pmid):
    """
    Fetches a single abstract for a given PMID using EFetch.
    Returns (pmid, text).
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "retmode": "text",
        "rettype": "abstract",
        "id": pmid,
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=30)
    resp.raise_for_status()
    raw_text = resp.text.strip()
    # If no text came back, mark the abstract as empty
    if not raw_text:
        return (pmid, "No abstract text found.")
    return (pmid, raw_text)
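# Note: with rettype="abstract" / retmode="text", EFetch typically returns plain
# text that includes the citation header (journal, title, authors) followed by
# the abstract body, so downstream summarization sees the full record text.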
def fetch_pubmed_abstracts(pmids):
    """
    Fetches multiple PMIDs in parallel to reduce overall latency.
    Returns {pmid: abstract_text}.
    """
    abstracts_map = {}
    if not pmids:
        # Guard: ThreadPoolExecutor raises ValueError if max_workers is 0
        return abstracts_map
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
        future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                pmid_result, text = future.result()
                abstracts_map[pmid_result] = text
            except Exception as e:
                abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
    return abstracts_map
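# Optional sketch: NCBI asks unauthenticated clients to stay under ~3 requests
# per second. A minimal shared throttle (hypothetical helper, not wired into the
# functions above; assumes an extra `import time`) could look like:
#
#   _throttle_lock = threading.Lock()
#   _last_request = [0.0]
#
#   def throttled_get(url, params, min_interval=0.34):
#       with _throttle_lock:
#           wait = min_interval - (time.time() - _last_request[0])
#           if wait > 0:
#               time.sleep(wait)
#           _last_request[0] = time.time()
#       return requests.get(url, params=params, timeout=30)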
###############################################################################
# ABSTRACT CHUNKING + SUMMARIZATION LOGIC #
###############################################################################
def chunk_and_summarize(abstract_text, chunk_size=512):
    """
    Splits a large abstract into manageable chunks (by sentences), then
    summarizes each chunk with the Hugging Face pipeline. Returns a combined
    summary for the entire abstract. Note that chunk_size is measured in
    whitespace-delimited words, a rough proxy for model tokens.
    """
    # Split into sentences first so chunk boundaries fall between sentences
    sentences = sent_tokenize(abstract_text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sent in sentences:
        tokens_in_sent = len(sent.split())
        # If adding this sentence would exceed chunk_size, finalize the current
        # chunk (skip if it is empty, e.g. a single sentence longer than chunk_size)
        if current_length + tokens_in_sent > chunk_size and current_chunk:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += tokens_in_sent
    # Final chunk, if any sentences remain
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    # Summarize each chunk to stay within the model's input-length constraints
    summarized_pieces = []
    for c in chunks:
        summary_out = summarizer(
            c,
            max_length=100,  # tweak for desired summary length
            min_length=30,
            do_sample=False
        )
        summarized_pieces.append(summary_out[0]['summary_text'])
    # Combine the partial summaries into one final text
    final_summary = " ".join(summarized_pieces)
    return final_summary.strip()
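# Worked example: with chunk_size=512, a 1,200-word abstract splits into roughly
# three chunks (exact counts shift with sentence boundaries). Each chunk yields
# a 30-100 token summary, so the combined summary runs ~90-300 tokens.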
###############################################################################
# LLM CALLS (OpenAI / Gemini) #
###############################################################################
def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
    """
    Basic ChatCompletion call with system + user roles for OpenAI.
    Uses the legacy (pre-1.0) openai client API; pin openai<1.0.
    """
    if not OPENAI_API_KEY:
        return "Error: OpenAI API key not provided."
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=temperature
        )
        return response.choices[0].message["content"].strip()
    except Exception as e:
        return f"Error calling OpenAI: {str(e)}"
def gemini_chat(system_prompt, user_message, model_name="gemini-pro", temperature=0.3):
    """
    Basic Gemini call using google.generativeai. The system prompt is prepended
    to the user message because the chat history in this library only accepts
    'user' and 'model' roles, and temperature must be passed via
    generation_config rather than as a send_message argument.
    """
    if not GEMINI_API_KEY:
        return "Error: Gemini API key not provided."
    try:
        model = genai.GenerativeModel(model_name=model_name)
        response = model.generate_content(
            f"{system_prompt}\n\n{user_message}",
            generation_config={"temperature": temperature}
        )
        return response.text
    except Exception as e:
        return f"Error calling Gemini: {str(e)}"
###############################################################################
# BUILD REFERENCES FOR ANSWER #
###############################################################################
def build_system_prompt_with_refs(pmids, summarized_map):
    """
    Creates a system prompt that includes the summarized abstracts alongside
    labeled references ([Ref1], [Ref2], ...), so the LLM can quote or cite
    specific references in its answer.
    """
    system_context = (
        "You have access to the following summarized PubMed articles. "
        "When relevant, cite them in your final answer using their reference label.\n\n"
    )
    for idx, pmid in enumerate(pmids, start=1):
        ref_label = f"[Ref{idx}]"
        system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
    system_context += "Use this contextual info to provide a concise, evidence-based answer."
    return system_context
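# Example shape of the generated prompt (PMIDs and summaries are illustrative):
#
#   You have access to the following summarized PubMed articles. When relevant,
#   cite them in your final answer using their reference label.
#
#   [Ref1] (PMID 37123456): First summarized abstract...
#   [Ref2] (PMID 36987654): Second summarized abstract...
#
#   Use this contextual info to provide a concise, evidence-based answer.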
###############################################################################
# STREAMLIT APP #
###############################################################################
def main():
    st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
    st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
    st.markdown("""
**Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
using PubMed E-utilities, Hugging Face summarization, and optional LLM calls (OpenAI or Gemini).

This version includes:
- **Parallel** fetching for multiple PMIDs
- Advanced **chunking & summarization** of large abstracts
- **Reference labeling** in the final answer
- Clear disclaimers & best-practice structure

---
**Disclaimer**: This is a demonstration prototype for educational or research purposes.
It is *not* a substitute for professional medical advice. Always consult a qualified
healthcare provider for personal health decisions.
""")
    user_query = st.text_area(
        "Enter your medical question or topic:",
        placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
        height=120
    )
    # Retrieval and model parameters, laid out in two columns
    col1, col2 = st.columns(2)
    with col1:
        max_papers = st.slider(
            "Number of PubMed Articles to Retrieve",
            min_value=1,
            max_value=10,
            value=3,
            help="Number of articles to fetch & summarize."
        )
    with col2:
        selected_llm = st.selectbox(
            "Select LLM for Final Generation",
            ["OpenAI: GPT-3.5", "Google: Gemini"],
            help="Choose which large language model generates the final answer."
        )
    # Additional advanced parameter: summarization chunk size
    chunk_size = st.slider(
        "Summarization Chunk Size (words)",
        min_value=256,
        max_value=1024,
        value=512,
        help="Larger chunks produce fewer summaries but risk hitting model token limits; smaller chunks summarize more reliably."
    )
    if st.button("Run Enhanced RAG Pipeline"):
        if not user_query.strip():
            st.warning("Please enter a query before running RAG.")
            return
        # 1. PubMed search
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(query=user_query, max_results=max_papers)
        if not pmids:
            st.error("No matching PubMed results. Try a different query.")
            return
        # 2. Fetch abstracts in parallel, then summarize each one
        with st.spinner("Fetching and summarizing abstracts..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, abstract_text in abstracts_map.items():
                if "Error fetching" in abstract_text:
                    summarized_map[pmid] = abstract_text
                else:
                    summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)
        # 3. Display the summaries
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            ref_label = f"[Ref{idx}]"
            st.markdown(f"**{ref_label} PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")
        # 4. Build the reference-labeled system prompt and generate the answer
        st.subheader("Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating final answer..."):
            if selected_llm == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
            else:
                answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)
        st.write(answer)
        st.success("RAG Pipeline Complete.")
    # Production considerations & next steps
    st.markdown("---")
    st.markdown("""
### Production-Ready Enhancements:
1. **Vector Databases & Advanced Retrieval**
   - For large-scale usage, index PubMed articles in a vector DB (e.g., Pinecone, Weaviate) to quickly retrieve relevant passages.
2. **Citation Parsing**
   - Automatically detect which abstract chunks contributed to each sentence of the answer.
3. **Multi-Lingual Support**
   - Integrate translation pipelines for non-English queries or abstracts.
4. **Rate Limiting**
   - Respect NCBI's ~3 requests/sec guideline if you're scaling out.
5. **Robust Logging & Error Handling**
   - Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing.
6. **Privacy & Security**
   - This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines.
""")
if __name__ == "__main__":
main()