# Medapp / app.py
import os
import re
import json
import math
import requests
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import streamlit as st
import pandas as pd
# NLP
import nltk
# Download sentence-tokenizer data quietly; newer NLTK releases also ship it as 'punkt_tab'.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
from nltk.tokenize import sent_tokenize
# Hugging Face Transformers
from transformers import pipeline
# Optional: OpenAI and Google Generative AI
import openai
import google.generativeai as genai
###############################################################################
# CONFIG & ENV #
###############################################################################
"""
In your Hugging Face Space:
1. Add environment secrets:
- OPENAI_API_KEY (if using OpenAI)
- GEMINI_API_KEY (if using Google PaLM/Gemini)
- MY_PUBMED_EMAIL (to identify yourself to NCBI)
2. In requirements.txt, install:
- streamlit
- requests
- nltk
- transformers
- torch
- openai<1.0 (if using OpenAI; the code uses the legacy ChatCompletion API)
- google-generativeai (if using Gemini)
- pandas
"""
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
if OPENAI_API_KEY:
openai.api_key = OPENAI_API_KEY
if GEMINI_API_KEY:
genai.configure(api_key=GEMINI_API_KEY)
###############################################################################
# SUMMARIZATION PIPELINE #
###############################################################################
@st.cache_resource
def load_summarizer():
"""
Load a summarization model (e.g., BART, PEGASUS, T5).
For a more concise summarization, consider: 'google/pegasus-xsum'
For a balanced approach, 'facebook/bart-large-cnn' is popular.
"""
return pipeline(
"summarization",
model="facebook/bart-large-cnn",
tokenizer="facebook/bart-large-cnn"
)
summarizer = load_summarizer()
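# Optional, not part of the original flow: Transformers pipelines run on CPU by default.
# If the Space has a GPU, load_summarizer() above could pass a device index instead; a
# minimal sketch of an alternative loader body (torch is already in requirements.txt):
#
#     import torch
#     return pipeline(
#         "summarization",
#         model="facebook/bart-large-cnn",
#         tokenizer="facebook/bart-large-cnn",
#         device=0 if torch.cuda.is_available() else -1,
#     )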
###############################################################################
# PUBMED RETRIEVAL (NCBI E-utilities) #
###############################################################################
def search_pubmed(query, max_results=3):
"""
Searches PubMed for PMIDs matching the query.
Includes recommended 'tool' and 'email' in the request.
"""
base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
params = {
"db": "pubmed",
"term": query,
"retmax": max_results,
"retmode": "json",
"tool": "ElysiumRAG",
"email": MY_PUBMED_EMAIL
}
    resp = requests.get(base_url, params=params, timeout=10)
resp.raise_for_status()
data = resp.json()
id_list = data.get("esearchresult", {}).get("idlist", [])
return id_list
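# Illustrative usage (the query and PMIDs here are placeholders, shown only to make the
# return shape explicit):
#
#     search_pubmed("type 2 diabetes complications treatment", max_results=2)
#     -> ["00000001", "00000002"]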
def fetch_one_abstract(pmid):
"""
Fetches a single abstract for a given PMID using EFetch.
Returns (pmid, text).
"""
base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
params = {
"db": "pubmed",
"retmode": "text",
"rettype": "abstract",
"id": pmid,
"tool": "ElysiumRAG",
"email": MY_PUBMED_EMAIL
}
    resp = requests.get(base_url, params=params, timeout=10)
resp.raise_for_status()
raw_text = resp.text.strip()
# If there's no clear text returned, mark as empty
if not raw_text:
return (pmid, "No abstract text found.")
return (pmid, raw_text)
def fetch_pubmed_abstracts(pmids):
"""
Parallel fetching of multiple PMIDs to reduce overall latency.
Returns {pmid: abstract_text}.
"""
    abstracts_map = {}
    if not pmids:
        # Guard: ThreadPoolExecutor rejects max_workers=0, so bail out early on an empty list.
        return abstracts_map
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
for future in as_completed(future_to_pmid):
pmid = future_to_pmid[future]
try:
pmid_result, text = future.result()
abstracts_map[pmid_result] = text
except Exception as e:
abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
return abstracts_map
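# Optional sketch, not wired into the app: NCBI asks unauthenticated clients to stay at
# roughly 3 requests/second to the E-utilities. Below is a minimal thread-safe throttle
# that the fetch helpers above could use in place of requests.get; the helper and the
# interval value are assumptions, not part of the original pipeline.
import time

_EUTILS_LOCK = threading.Lock()
_LAST_EUTILS_CALL = [0.0]

def throttled_get(url, params, min_interval=0.34, timeout=10):
    """requests.get with calls spaced at least `min_interval` seconds apart across threads."""
    with _EUTILS_LOCK:
        wait = _LAST_EUTILS_CALL[0] + min_interval - time.time()
        if wait > 0:
            time.sleep(wait)
        _LAST_EUTILS_CALL[0] = time.time()
    return requests.get(url, params=params, timeout=timeout)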
###############################################################################
# ABSTRACT CHUNKING + SUMMARIZATION LOGIC #
###############################################################################
def chunk_and_summarize(abstract_text, chunk_size=512):
"""
Splits a large abstract into manageable chunks (by sentences),
then summarizes each chunk with the Hugging Face pipeline.
Returns a combined summary for the entire abstract.
"""
# We first split by sentences
sentences = sent_tokenize(abstract_text)
chunks = []
current_chunk = []
current_length = 0
    for sent in sentences:
        tokens_in_sent = len(sent.split())
        # If adding this sentence would exceed the chunk_size limit, finalize the current
        # chunk first (only if it has content, so a single very long sentence never
        # produces an empty chunk).
        if current_chunk and current_length + tokens_in_sent > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += tokens_in_sent
    # Final chunk if it exists
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    # Summarize each chunk to avoid hitting token or length constraints
    summarized_pieces = []
    for c in chunks:
        # Scale the target lengths to the chunk so very short chunks are not asked
        # for summaries longer than their input.
        input_len = len(c.split())
        summary_out = summarizer(
            c,
            max_length=min(100, max(20, input_len // 2)),  # tweak for desired summary length
            min_length=min(30, max(5, input_len // 4)),
            do_sample=False
        )
        summarized_pieces.append(summary_out[0]['summary_text'])
# Combine partial summaries into one final text
final_summary = " ".join(summarized_pieces)
return final_summary.strip()
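# Worked example of the chunking above (numbers are illustrative): a 700-word abstract
# with chunk_size=512 is split at a sentence boundary into roughly a 500-word chunk and
# a 200-word chunk; each chunk is summarized independently and the partial summaries are
# concatenated into the final text that is returned.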
###############################################################################
# LLM CALLS (OpenAI / Gemini) #
###############################################################################
def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
"""
Basic ChatCompletion with a system + user role for OpenAI.
"""
if not OPENAI_API_KEY:
return "Error: OpenAI API key not provided."
    try:
        # Note: openai.ChatCompletion is the legacy interface from openai<1.0; with
        # openai>=1.0 this call would need to be rewritten against the new client API.
        response = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message}
],
temperature=temperature
)
return response.choices[0].message["content"].strip()
except Exception as e:
return f"Error calling OpenAI: {str(e)}"
def gemini_chat(system_prompt, user_message, model_name="gemini-1.5-flash", temperature=0.3):
    """
    Basic Gemini chat call using google.generativeai.
    Note: the default model name here is an assumption; the original "models/chat-bison-001"
    (PaLM 2) is not served through GenerativeModel. The system prompt is simply prepended
    to the user message, since older SDK versions have no dedicated system role.
    """
    if not GEMINI_API_KEY:
        return "Error: Gemini API key not provided."
    try:
        model = genai.GenerativeModel(model_name=model_name)
        response = model.generate_content(
            f"{system_prompt}\n\n{user_message}",
            generation_config={"temperature": temperature}
        )
        return response.text
    except Exception as e:
        return f"Error calling Gemini: {str(e)}"
###############################################################################
# BUILD REFERENCES FOR ANSWER #
###############################################################################
def build_system_prompt_with_refs(pmids, summarized_map):
"""
Creates a system prompt that includes the summarized abstracts alongside
labeled references. This allows the LLM to quote or cite specific references.
"""
# Example of labeling references: [Ref1], [Ref2], etc.
system_context = (
"You have access to the following summarized PubMed articles. "
"When relevant, cite them in your final answer using their reference label.\n\n"
)
for idx, pmid in enumerate(pmids, start=1):
ref_label = f"[Ref{idx}]"
system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
system_context += "Use this contextual info to provide a concise, evidence-based answer."
return system_context
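# Illustrative rendering of build_system_prompt_with_refs for two hypothetical PMIDs
# (the PMIDs and summaries are placeholders, not real articles):
#
#     You have access to the following summarized PubMed articles. When relevant,
#     cite them in your final answer using their reference label.
#
#     [Ref1] (PMID 00000001): <summary of the first abstract>
#
#     [Ref2] (PMID 00000002): <summary of the second abstract>
#
#     Use this contextual info to provide a concise, evidence-based answer.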
###############################################################################
# STREAMLIT APP #
###############################################################################
def main():
st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
st.markdown("""
**Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
using PubMed E-utilities, Hugging Face Summarization, and optional LLM calls (OpenAI or Gemini).
This version includes:
- **Parallel** fetching for multiple PMIDs
- Advanced **chunking & summarization** of large abstracts
- **Reference labeling** in the final answer
- Clear disclaimers & best-practice structures
---
**Disclaimer**: This is a demonstration prototype for educational or research purposes.
It is *not* a substitute for professional medical advice. Always consult a qualified
healthcare provider for personal health decisions.
""")
user_query = st.text_area(
"Enter your medical question or topic:",
placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
height=120
)
# Sidebar or columns for parameters
col1, col2 = st.columns(2)
with col1:
max_papers = st.slider(
"Number of PubMed Articles to Retrieve",
min_value=1,
max_value=10,
value=3,
help="Number of articles to fetch & summarize."
)
with col2:
selected_llm = st.selectbox(
"Select LLM for Final Generation",
["OpenAI: GPT-3.5", "Gemini: PaLM2"],
help="Choose which large language model to finalize the answer."
)
# Additional advanced parameter: chunk size
chunk_size = st.slider(
"Summarization Chunk Size (words)",
min_value=256,
max_value=1024,
value=512,
help="Larger chunks might produce fewer summaries, but risk token limits. Smaller chunks produce more robust summaries."
)
if st.button("Run Enhanced RAG Pipeline"):
if not user_query.strip():
st.warning("Please enter a query before running RAG.")
return
# 1. PubMed Search
with st.spinner("Searching PubMed..."):
pmids = search_pubmed(query=user_query, max_results=max_papers)
if not pmids:
st.error("No matching PubMed results. Try a different query.")
return
# 2. Fetch abstracts in parallel
with st.spinner("Fetching and summarizing abstracts..."):
abstracts_map = fetch_pubmed_abstracts(pmids)
summarized_map = {}
for pmid, abstract_text in abstracts_map.items():
if "Error fetching" in abstract_text:
summarized_map[pmid] = abstract_text
else:
summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)
# 3. Display Summaries
st.subheader("Retrieved & Summarized PubMed Articles")
for idx, pmid in enumerate(pmids, start=1):
ref_label = f"[Ref{idx}]"
st.markdown(f"**{ref_label} PMID {pmid}**")
st.write(summarized_map[pmid])
st.write("---")
# 4. Build System Prompt
st.subheader("Final Answer")
system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
with st.spinner("Generating final answer..."):
if selected_llm == "OpenAI: GPT-3.5":
answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
else:
answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)
st.write(answer)
st.success("RAG Pipeline Complete.")
# Production Considerations & Next Steps
st.markdown("---")
st.markdown("""
### Production-Ready Enhancements:
1. **Vector Databases & Advanced Retrieval**
- For large-scale usage, index PubMed articles in a vector DB (e.g. Pinecone, Weaviate) to quickly retrieve relevant passages.
2. **Citation Parsing**
- Automatically detect which abstract chunks contributed to each sentence.
3. **Multi-Lingual**
- Integrate translation pipelines for non-English queries or abstracts.
4. **Rate Limiting**
- Respect NCBI's ~3 requests/sec guideline if you're scaling out.
5. **Robust Logging & Error Handling**
- Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing.
6. **Privacy & Security**
- This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines.
""")
if __name__ == "__main__":
main()