mgbam committed
Commit 8225d31 · verified · 1 Parent(s): 325b480

Update app.py

Files changed (1)
  1. app.py +353 -0
app.py CHANGED
@@ -0,0 +1,353 @@
+ import os
+ import requests
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ import streamlit as st
+
+ # NLP
+ import nltk
+ nltk.download('punkt', quiet=True)
+ from nltk.tokenize import sent_tokenize
+
+ # Hugging Face Transformers
+ from transformers import pipeline
+
+ # Optional: OpenAI and Google Generative AI
+ import openai
+ import google.generativeai as genai
+
+ ###############################################################################
+ #                               CONFIG & ENV                                  #
+ ###############################################################################
+ """
+ In your Hugging Face Space:
+ 1. Add environment secrets:
+    - OPENAI_API_KEY   (if using OpenAI)
+    - GEMINI_API_KEY   (if using Google PaLM/Gemini)
+    - MY_PUBMED_EMAIL  (to identify yourself to NCBI)
+ 2. In requirements.txt, install:
+    - streamlit
+    - requests
+    - nltk
+    - transformers
+    - torch
+    - openai               (if using OpenAI)
+    - google-generativeai  (if using Gemini)
+ """
+
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+ MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
+
+ if OPENAI_API_KEY:
+     openai.api_key = OPENAI_API_KEY
+
+ if GEMINI_API_KEY:
+     genai.configure(api_key=GEMINI_API_KEY)
+
+ ###############################################################################
+ #                          SUMMARIZATION PIPELINE                             #
+ ###############################################################################
+ @st.cache_resource
+ def load_summarizer():
+     """
+     Load a summarization model (e.g., BART, PEGASUS, T5).
+     For more concise summaries, consider 'google/pegasus-xsum'.
+     For a balanced approach, 'facebook/bart-large-cnn' is popular.
+     """
+     return pipeline(
+         "summarization",
+         model="facebook/bart-large-cnn",
+         tokenizer="facebook/bart-large-cnn"
+     )
+
+ summarizer = load_summarizer()
+
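+ # Note: facebook/bart-large-cnn accepts roughly 1024 tokens of input,
+ # which is why longer abstracts are chunked before summarization below.
+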
+ ###############################################################################
+ #                    PUBMED RETRIEVAL (NCBI E-utilities)                      #
+ ###############################################################################
+ def search_pubmed(query, max_results=3):
+     """
+     Searches PubMed for PMIDs matching the query.
+     Includes the recommended 'tool' and 'email' parameters in the request.
+     """
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
+     params = {
+         "db": "pubmed",
+         "term": query,
+         "retmax": max_results,
+         "retmode": "json",
+         "tool": "ElysiumRAG",
+         "email": MY_PUBMED_EMAIL
+     }
+     resp = requests.get(base_url, params=params, timeout=10)
+     resp.raise_for_status()
+     data = resp.json()
+     id_list = data.get("esearchresult", {}).get("idlist", [])
+     return id_list
+
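+ # For reference, a successful ESearch JSON response has roughly the shape
+ #   {"esearchresult": {"count": "...", "idlist": ["38812345", ...]}}
+ # (PMIDs here are illustrative), so this returns a list of PMID strings.
+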
+ def fetch_one_abstract(pmid):
+     """
+     Fetches a single abstract for a given PMID using EFetch.
+     Returns (pmid, text).
+     """
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
+     params = {
+         "db": "pubmed",
+         "retmode": "text",
+         "rettype": "abstract",
+         "id": pmid,
+         "tool": "ElysiumRAG",
+         "email": MY_PUBMED_EMAIL
+     }
+     resp = requests.get(base_url, params=params, timeout=10)
+     resp.raise_for_status()
+     raw_text = resp.text.strip()
+
+     # If no text came back, return a placeholder message instead
+     if not raw_text:
+         return (pmid, "No abstract text found.")
+
+     return (pmid, raw_text)
+
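+ # EFetch also accepts a comma-separated list of IDs in a single request
+ # (e.g. id="123,456"), which would cut request volume; per-PMID calls are
+ # kept here so each abstract maps cleanly back to its PMID.
+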
+ def fetch_pubmed_abstracts(pmids):
+     """
+     Parallel fetching of multiple PMIDs to reduce overall latency.
+     Returns {pmid: abstract_text}.
+     """
+     abstracts_map = {}
+     # Cap the pool at 5 workers (and at least 1, so an empty/short list
+     # can't crash the executor). Note NCBI asks for <= 3 requests/sec
+     # without an API key, so heavier usage should throttle these calls.
+     max_workers = max(1, min(len(pmids), 5))
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
+         for future in as_completed(future_to_pmid):
+             pmid = future_to_pmid[future]
+             try:
+                 pmid_result, text = future.result()
+                 abstracts_map[pmid_result] = text
+             except Exception as e:
+                 abstracts_map[pmid] = f"Error fetching abstract: {str(e)}"
+     return abstracts_map
+
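+ # Example (hypothetical PMIDs):
+ #   fetch_pubmed_abstracts(["12345678", "23456789"])
+ #   -> {"12345678": "<abstract text>", "23456789": "<abstract text>"}
+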
+ ###############################################################################
+ #                  ABSTRACT CHUNKING + SUMMARIZATION LOGIC                    #
+ ###############################################################################
+ def chunk_and_summarize(abstract_text, chunk_size=512):
+     """
+     Splits a large abstract into manageable chunks (by sentences),
+     then summarizes each chunk with the Hugging Face pipeline.
+     Returns a combined summary for the entire abstract.
+     """
+     # Split into sentences first so chunk boundaries never cut a sentence
+     sentences = sent_tokenize(abstract_text)
+     chunks = []
+
+     current_chunk = []
+     current_length = 0
+     for sent in sentences:
+         # Whitespace word count serves as a cheap proxy for model tokens
+         tokens_in_sent = len(sent.split())
+         # If adding this sentence would exceed chunk_size, finalize the
+         # current chunk (skipping if it's empty, e.g. an oversized sentence)
+         if current_length + tokens_in_sent > chunk_size and current_chunk:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = []
+             current_length = 0
+         current_chunk.append(sent)
+         current_length += tokens_in_sent
+
+     # Final chunk, if any sentences remain
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     # Summarize each chunk to avoid hitting token or length constraints
+     summarized_pieces = []
+     for c in chunks:
+         summary_out = summarizer(
+             c,
+             max_length=100,  # tweak for desired summary length
+             min_length=30,
+             do_sample=False
+         )
+         summarized_pieces.append(summary_out[0]['summary_text'])
+
+     # Combine partial summaries into one final text
+     final_summary = " ".join(summarized_pieces)
+     return final_summary.strip()
+
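+ # Worked example: with chunk_size=512 and sentence word counts [300, 250, 100],
+ # sentence 1 fills the first chunk (300 words) and sentences 2-3 form the
+ # second (350 words); each chunk is summarized separately, then joined.
+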
+ ###############################################################################
+ #                         LLM CALLS (OpenAI / Gemini)                         #
+ ###############################################################################
+ def openai_chat(system_prompt, user_message, model="gpt-3.5-turbo", temperature=0.3):
+     """
+     Basic ChatCompletion with a system + user role for OpenAI.
+     Note: this uses the legacy (pre-1.0) openai SDK interface; pin
+     openai<1.0 in requirements.txt or migrate to the newer client.
+     """
+     if not OPENAI_API_KEY:
+         return "Error: OpenAI API key not provided."
+     try:
+         response = openai.ChatCompletion.create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_message}
+             ],
+             temperature=temperature
+         )
+         return response.choices[0].message["content"].strip()
+     except Exception as e:
+         return f"Error calling OpenAI: {str(e)}"
+
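+ # Example call (requires OPENAI_API_KEY to be set; prompts are illustrative):
+ #   openai_chat("You are a concise medical assistant.",
+ #               "Summarize retrieval-augmented generation in one sentence.")
+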
+ def gemini_chat(system_prompt, user_message, model_name="gemini-pro", temperature=0.3):
+     """
+     Basic Gemini chat call via google.generativeai. The system prompt is
+     prepended to the user message, since older library versions expose no
+     dedicated system role; model_name may need updating as models evolve.
+     """
+     if not GEMINI_API_KEY:
+         return "Error: Gemini API key not provided."
+     try:
+         model = genai.GenerativeModel(model_name=model_name)
+         response = model.generate_content(
+             f"{system_prompt}\n\n{user_message}",
+             generation_config={"temperature": temperature}
+         )
+         return response.text
+     except Exception as e:
+         return f"Error calling Gemini: {str(e)}"
+
+ ###############################################################################
+ #                        BUILD REFERENCES FOR ANSWER                          #
+ ###############################################################################
+ def build_system_prompt_with_refs(pmids, summarized_map):
+     """
+     Creates a system prompt that includes the summarized abstracts alongside
+     labeled references, so the LLM can quote or cite specific references.
+     """
+     # References are labeled [Ref1], [Ref2], etc.
+     system_context = (
+         "You have access to the following summarized PubMed articles. "
+         "When relevant, cite them in your final answer using their reference label.\n\n"
+     )
+     for idx, pmid in enumerate(pmids, start=1):
+         ref_label = f"[Ref{idx}]"
+         system_context += f"{ref_label} (PMID {pmid}): {summarized_map[pmid]}\n\n"
+     system_context += "Use this contextual info to provide a concise, evidence-based answer."
+     return system_context
+
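+ # The assembled prompt looks roughly like (PMIDs illustrative):
+ #   You have access to the following summarized PubMed articles. ...
+ #   [Ref1] (PMID 12345678): <summary of article 1>
+ #   [Ref2] (PMID 23456789): <summary of article 2>
+ #   Use this contextual info to provide a concise, evidence-based answer.
+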
+ ###############################################################################
+ #                               STREAMLIT APP                                 #
+ ###############################################################################
+ def main():
+     st.set_page_config(page_title="Enhanced RAG + PubMed", layout="wide")
+     st.title("Enhanced RAG + PubMed: Production-Ready Medical Insights")
+
+     st.markdown("""
+     **Welcome** to an advanced demonstration of **Retrieval-Augmented Generation (RAG)**
+     using PubMed E-utilities, Hugging Face summarization, and optional LLM calls (OpenAI or Gemini).
+
+     This version includes:
+     - **Parallel** fetching for multiple PMIDs
+     - Advanced **chunking & summarization** of large abstracts
+     - **Reference labeling** in the final answer
+     - Clear disclaimers & best-practice structure
+
+     ---
+     **Disclaimer**: This is a demonstration prototype for educational or research purposes.
+     It is *not* a substitute for professional medical advice. Always consult a qualified
+     healthcare provider for personal health decisions.
+     """)
+
+     user_query = st.text_area(
+         "Enter your medical question or topic:",
+         placeholder="e.g., 'What are the latest treatments for type 2 diabetes complications?'",
+         height=120
+     )
+
+     # Layout parameters in two columns
+     col1, col2 = st.columns(2)
+     with col1:
+         max_papers = st.slider(
+             "Number of PubMed Articles to Retrieve",
+             min_value=1,
+             max_value=10,
+             value=3,
+             help="Number of articles to fetch & summarize."
+         )
+     with col2:
+         selected_llm = st.selectbox(
+             "Select LLM for Final Generation",
+             ["OpenAI: GPT-3.5", "Gemini: PaLM2"],
+             help="Choose which large language model to finalize the answer."
+         )
+
+     # Additional advanced parameter: chunk size
+     chunk_size = st.slider(
+         "Summarization Chunk Size (words)",
+         min_value=256,
+         max_value=1024,
+         value=512,
+         help="Larger chunks mean fewer summarization calls but risk model token limits; smaller chunks are safer but produce more pieces."
+     )
+
+     if st.button("Run Enhanced RAG Pipeline"):
+         if not user_query.strip():
+             st.warning("Please enter a query before running RAG.")
+             return
+
+         # 1. PubMed search
+         with st.spinner("Searching PubMed..."):
+             pmids = search_pubmed(query=user_query, max_results=max_papers)
+
+         if not pmids:
+             st.error("No matching PubMed results. Try a different query.")
+             return
+
+         # 2. Fetch abstracts in parallel, then summarize each
+         with st.spinner("Fetching and summarizing abstracts..."):
+             abstracts_map = fetch_pubmed_abstracts(pmids)
+             summarized_map = {}
+             for pmid, abstract_text in abstracts_map.items():
+                 if abstract_text.startswith("Error fetching"):
+                     summarized_map[pmid] = abstract_text
+                 else:
+                     summarized_map[pmid] = chunk_and_summarize(abstract_text, chunk_size=chunk_size)
+
+         # 3. Display summaries
+         st.subheader("Retrieved & Summarized PubMed Articles")
+         for idx, pmid in enumerate(pmids, start=1):
+             ref_label = f"[Ref{idx}]"
+             st.markdown(f"**{ref_label} PMID {pmid}**")
+             st.write(summarized_map[pmid])
+             st.write("---")
+
+         # 4. Build the reference-labeled system prompt and generate the answer
+         st.subheader("Final Answer")
+         system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
+
+         with st.spinner("Generating final answer..."):
+             if selected_llm == "OpenAI: GPT-3.5":
+                 answer = openai_chat(system_prompt=system_prompt, user_message=user_query)
+             else:
+                 answer = gemini_chat(system_prompt=system_prompt, user_message=user_query)
+
+         st.write(answer)
+         st.success("RAG Pipeline Complete.")
+
+     # Production considerations & next steps
+     st.markdown("---")
+     st.markdown("""
+     ### Production-Ready Enhancements:
+     1. **Vector Databases & Advanced Retrieval**
+        - For large-scale usage, index PubMed articles in a vector DB (e.g. Pinecone, Weaviate) to quickly retrieve relevant passages.
+     2. **Citation Parsing**
+        - Automatically detect which abstract chunks contributed to each sentence of the answer.
+     3. **Multi-Lingual Support**
+        - Integrate translation pipelines for non-English queries or abstracts.
+     4. **Rate Limiting**
+        - Respect NCBI's ~3 requests/sec guideline if you're scaling out.
+     5. **Robust Logging & Error Handling**
+        - Build out logs, handle exceptions gracefully, and provide fallback prompts if an LLM fails or an abstract is missing.
+     6. **Privacy & Security**
+        - This demo only fetches public info. For patient data, ensure HIPAA/GDPR compliance and encrypted data pipelines.
+     """)
+
+ if __name__ == "__main__":
+     main()