import os
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

import streamlit as st

# NLP
import nltk
nltk.download('punkt', quiet=True)
# Newer NLTK releases (>= 3.8.2) look for 'punkt_tab' instead of 'punkt'.
nltk.download('punkt_tab', quiet=True)
from nltk.tokenize import sent_tokenize

# Hugging Face Transformers
from transformers import pipeline

# Optional: OpenAI and Google Generative AI
import openai
import google.generativeai as genai

###############################################################################
#                              CONFIG & ENV                                   #
###############################################################################
"""
In your Hugging Face Space:
1. Add environment secrets:
   - OPENAI_API_KEY       (if using OpenAI)
   - GEMINI_API_KEY       (if using Google PaLM/Gemini)
   - MY_PUBMED_EMAIL      (to identify yourself to NCBI)
2. In requirements.txt, install:
   - streamlit
   - requests
   - nltk
   - transformers
   - torch
   - openai (if using OpenAI)
   - google-generativeai (if using Gemini)
   - pandas
"""

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")

if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
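# For local testing outside a Space, the same secrets can be exported before
# launching (the file name below is illustrative, not fixed by this demo):
#   OPENAI_API_KEY=sk-... MY_PUBMED_EMAIL=you@example.org streamlit run app.py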

###############################################################################
#                           SUMMARIZATION PIPELINE                            #
###############################################################################
@st.cache_resource
def load_summarizer():
    """
    Load a summarization model (e.g., BART, PEGASUS, T5).
    For a more concise summarization, consider: 'google/pegasus-xsum'
    For a balanced approach, 'facebook/bart-large-cnn' is popular.
    """
    return pipeline(
        "summarization", 
        model="facebook/bart-large-cnn", 
        tokenizer="facebook/bart-large-cnn"
    )

summarizer = load_summarizer()
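# Note: the first run downloads the BART weights (roughly 1.6 GB);
# @st.cache_resource then keeps the loaded pipeline in memory across
# Streamlit reruns instead of reloading it on every interaction.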

###############################################################################
#                      PUBMED RETRIEVAL (NCBI E-utilities)                    #
###############################################################################
def search_pubmed(query, max_results=3):
    """
    Searches PubMed for PMIDs matching the query.
    Includes recommended 'tool' and 'email' in the request.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=10)
    resp.raise_for_status()
    data = resp.json()
    id_list = data.get("esearchresult", {}).get("idlist", [])
    return id_list
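# Usage sketch (hypothetical query and PMIDs; requires network access):
#   search_pubmed("semaglutide weight loss", max_results=2)
#   -> ['12345678', '23456789']   # PMIDs come back as strings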

def fetch_one_abstract(pmid):
    """
    Fetches a single abstract for a given PMID using EFetch.
    Returns (pmid, text).
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "retmode": "text",
        "rettype": "abstract",
        "id": pmid,
        "tool": "ElysiumRAG",
        "email": MY_PUBMED_EMAIL
    }
    resp = requests.get(base_url, params=params, timeout=10)
    resp.raise_for_status()
    raw_text = resp.text.strip()
    
    # If there's no clear text returned, mark as empty
    if not raw_text:
        return (pmid, "No abstract text found.")
    
    return (pmid, raw_text)

def fetch_pubmed_abstracts(pmids):
    """
    Parallel fetching of multiple PMIDs to reduce overall latency.
    Returns {pmid: abstract_text}.
    """
    abstracts_map = {}
    if not pmids:
        return abstracts_map
    # Cap workers at 3 to stay within NCBI's ~3 requests/sec guideline.
    with ThreadPoolExecutor(max_workers=min(len(pmids), 3)) as executor:
        future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                pmid_result, text = future.result()
                abstracts_map[pmid_result] = text
            except Exception as e:
                abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
    return abstracts_map
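# Alternative sketch: EFetch also accepts a comma-separated "id" parameter, so
# a single request can replace the thread pool when per-PMID error handling is
# not needed (params as in fetch_one_abstract):
#   params["id"] = ",".join(pmids)
#   resp = requests.get(base_url, params=params, timeout=10)
#   # resp.text then holds all abstracts concatenated as plain text.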

###############################################################################
#                  ABSTRACT CHUNKING + SUMMARIZATION LOGIC                    #
###############################################################################
def chunk_and_summarize(abstract_text, chunk_size=512):
    """
    Splits a large abstract into manageable chunks (by sentences),
    then summarizes each chunk with the Hugging Face pipeline.
    Returns a combined summary for the entire abstract.
    """
    # We first split by sentences
    sentences = sent_tokenize(abstract_text)
    chunks = []
    
    current_chunk = []
    current_length = 0
    for sent in sentences:
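        # Whitespace word count is a rough proxy for model tokens; BART's hard
        # limit is ~1024 tokens, so the default of 512 words leaves headroom.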
        tokens_in_sent = len(sent.split())
        # If adding this sentence exceeds the chunk_size limit, finalize the chunk
        if current_length + tokens_in_sent > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += tokens_in_sent

    # Final chunk if it exists
    if current_chunk:
        chunks.append(" ".join(current_chunk))

    # Summarize each chunk to avoid hitting token or length constraints
    summarized_pieces = []
    for c in chunks:
        summary_out = summarizer(
            c,
            max_length=100,   # tweak for desired summary length
            min_length=30,
            do_sample=False
        )
        summarized_pieces.append(summary_out[0]['summary_text'])
    
    # Combine partial summaries into one final text
    final_summary = " ".join(summarized_pieces)
    return final_summary.strip()
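# Worked example (illustrative numbers): a 1,300-word abstract with
# chunk_size=512 splits into three chunks (~512 + ~512 + ~276 words); each is
# summarized to 30-100 tokens, so the combined summary stays well under the
# model's input limit if it is later re-summarized or embedded in a prompt.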

###############################################################################
#                      LLM CALLS (OpenAI / Gemini)                            #
###############################################################################
def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
    """
    Basic ChatCompletion with a system + user role for OpenAI.
    Uses the legacy openai<1.0 ChatCompletion interface (see requirements note).
    """
    if not OPENAI_API_KEY:
        return "Error: OpenAI API key not provided."
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=temperature
        )
        return response.choices[0].message["content"].strip()
    except Exception as e:
        return f"Error calling OpenAI: {str(e)}"

def gemini_chat(system_prompt, user_message, model_name="gemini-1.5-flash", temperature=0.3):
    """
    Basic Gemini chat call using google.generativeai. The SDK exposes no
    separate system role in this code path, so the system prompt is prepended
    to the user message.
    """
    if not GEMINI_API_KEY:
        return "Error: Gemini API key not provided."
    try:
        model = genai.GenerativeModel(model_name=model_name)
        response = model.generate_content(
            f"{system_prompt}\n\nUser question: {user_message}",
            generation_config=genai.types.GenerationConfig(temperature=temperature)
        )
        return response.text.strip()
    except Exception as e:
        return f"Error calling Gemini: {str(e)}"

###############################################################################
#                         BUILD REFERENCES FOR ANSWER                         #
###############################################################################
def build_system_prompt_with_refs(pmids, summarized_map):
    """
    Creates a system prompt that includes the summarized abstracts alongside 
    labeled references. This allows the LLM to quote or cite specific references.
    """
    # Example of labeling references: [Ref1], [Ref2], etc.
    system_context = (
        "You have access to the following summarized PubMed articles. "
        "When relevant, cite them in your final answer using their reference label.\n\n"
    )
    for idx, pmid in enumerate(pmids, start=1):
        ref_label = f"[Ref{idx}]"
        system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
    system_context += "Use this contextual info to provide a concise, evidence-based answer."
    return system_context
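# For two PMIDs the generated prompt looks roughly like this (PMIDs and
# summaries are placeholders):
#
#   You have access to the following summarized PubMed articles. When
#   relevant, cite them in your final answer using their reference label.
#
#   [Ref1] (PMID 12345678): <summary of article 1>
#
#   [Ref2] (PMID 23456789): <summary of article 2>
#
#   Use this contextual info to provide a concise, evidence-based answer.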

###############################################################################
#                                STREAMLIT APP                                #
###############################################################################
def main():
    st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
    st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")

    st.markdown("""
    **Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)** 
    using PubMed E-utilities, Hugging Face Summarization, and optional LLM calls (OpenAI or Gemini).
    
    This version includes:
    - **Parallel** fetching for multiple PMIDs
    - Advanced **chunking & summarization** of large abstracts
    - **Reference labeling** in the final answer
    - Clear disclaimers & best-practice structures
    
    ---
    **Disclaimer**: This is a demonstration prototype for educational or research purposes.
    It is *not* a substitute for professional medical advice. Always consult a qualified
    healthcare provider for personal health decisions.
    """)

    user_query = st.text_area(
        "Enter your medical question or topic:",
        placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
        height=120
    )

    # Sidebar or columns for parameters
    col1, col2 = st.columns(2)
    with col1:
        max_papers = st.slider(
            "Number of PubMed Articles to Retrieve",
            min_value=1,
            max_value=10,
            value=3,
            help="Number of articles to fetch & summarize."
        )
    with col2:
        selected_llm = st.selectbox(
            "Select LLM for Final Generation",
            ["OpenAI: GPT-3.5", "Google: Gemini"],
            help="Choose which large language model finalizes the answer."
        )

    # Additional advanced parameter: chunk size
    chunk_size = st.slider(
        "Summarization Chunk Size (words)",
        min_value=256,
        max_value=1024,
        value=512,
        help="Larger chunks might produce fewer summaries, but risk token limits. Smaller chunks produce more robust summaries."
    )

    if st.button("Run Enhanced RAG Pipeline"):
        if not user_query.strip():
            st.warning("Please enter a query before running RAG.")
            return

        # 1. PubMed Search
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(query=user_query, max_results=max_papers)
        
        if not pmids:
            st.error("No matching PubMed results. Try a different query.")
            return

        # 2. Fetch abstracts in parallel
        with st.spinner("Fetching and summarizing abstracts..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, abstract_text in abstracts_map.items():
                if "Error fetching" in abstract_text:
                    summarized_map[pmid] = abstract_text
                else:
                    summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)

        # 3. Display Summaries
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            ref_label = f"[Ref{idx}]"
            st.markdown(f"**{ref_label} PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        # 4. Build System Prompt
        st.subheader("Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)

        with st.spinner("Generating final answer..."):
            if selected_llm == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
            else:
                answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)

        st.write(answer)
        st.success("RAG Pipeline Complete.")

    # Production Considerations & Next Steps
    st.markdown("---")
    st.markdown("""
    ### Production-Ready Enhancements:
    1. **Vector Databases & Advanced Retrieval**  
       - For large-scale usage, index PubMed articles in a vector DB (e.g. Pinecone, Weaviate) to quickly retrieve relevant passages.  
    2. **Citation Parsing**  
       - Automatically detect which abstract chunks contributed to each sentence.  
    3. **Multi-Lingual**  
       - Integrate translation pipelines for non-English queries or abstracts.  
    4. **Rate Limiting**  
       - Respect NCBI's ~3 requests/sec guideline if you're scaling out.  
    5. **Robust Logging & Error Handling**  
       - Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing.  
    6. **Privacy & Security**  
       - This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines.
    """)

if __name__ == "__main__":
    main()