# pubmed_rag.py — PubMed search, abstract retrieval, and summarization helpers for Medapp.
import requests
from transformers import pipeline
from nltk.tokenize import sent_tokenize
import nltk
from config import MY_PUBMED_EMAIL, MAX_PUBMED_RESULTS, SUMMARIZATION_CHUNK_SIZE
# Ensure the Punkt sentence tokenizer model is available for sent_tokenize.
# NOTE(review): runs on every import; downloads only if the model is missing.
nltk.download("punkt")
# Summarization pipeline
# NOTE(review): loaded eagerly at import time — the BART model is large, so
# importing this module is slow on first run while weights are fetched.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
    """Query the PubMed ESearch endpoint and return matching article IDs.

    Args:
        query: Free-text PubMed search expression.
        max_results: Upper bound on the number of PMIDs returned.

    Returns:
        A list of PubMed ID strings (possibly empty).

    Raises:
        requests.HTTPError: If the E-utilities service returns an error status.
    """
    endpoint = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    query_params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "AdvancedMedicalAI",
        "email": MY_PUBMED_EMAIL,
    }
    resp = requests.get(endpoint, params=query_params, timeout=10)
    resp.raise_for_status()
    payload = resp.json()
    # Defensive lookup: an unexpected payload shape yields [] rather than KeyError.
    return payload.get("esearchresult", {}).get("idlist", [])
def fetch_abstract(pmid):
    """Retrieve the plain-text abstract for one PubMed ID via EFetch.

    Args:
        pmid: A PubMed ID string (or int) identifying the article.

    Returns:
        The abstract text with surrounding whitespace stripped.

    Raises:
        requests.HTTPError: If the E-utilities service returns an error status.
    """
    endpoint = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    query_params = {
        "db": "pubmed",
        "id": pmid,
        "retmode": "text",
        "rettype": "abstract",
        "tool": "AdvancedMedicalAI",
        "email": MY_PUBMED_EMAIL,
    }
    resp = requests.get(endpoint, params=query_params, timeout=10)
    resp.raise_for_status()
    body = resp.text
    return body.strip()
def fetch_pubmed_abstracts(pmids):
    """Fetch abstracts for a batch of PMIDs, mapping each PMID to its text.

    A failed fetch does not abort the batch: the error is recorded as the
    value for that PMID and processing continues with the remaining IDs.

    Args:
        pmids: Iterable of PubMed ID strings.

    Returns:
        Dict mapping each PMID to its abstract text, or to an error message
        string when retrieval failed.
    """
    abstracts = {}
    for pmid in pmids:
        try:
            abstracts[pmid] = fetch_abstract(pmid)
        except Exception as e:
            # Best-effort batch: record the failure in place of the abstract.
            abstracts[pmid] = f"Error fetching PMID {pmid}: {e}"
    return abstracts
def summarize_text(text, chunk_size=SUMMARIZATION_CHUNK_SIZE):
    """Summarize long text by splitting it into sentence chunks.

    Sentences are greedily packed into chunks of at most ``chunk_size``
    whitespace-delimited tokens (a single oversized sentence becomes its own
    chunk), each chunk is summarized independently, and the per-chunk
    summaries are joined into one string.

    Args:
        text: The text to summarize. Empty or whitespace-only input yields "".
        chunk_size: Approximate maximum token (word) count per chunk.

    Returns:
        The concatenated summaries of all chunks, joined by spaces.
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        tokens = len(sentence.split())
        # Bug fix: flush only a non-empty chunk. The original flushed
        # unconditionally, so a first sentence longer than chunk_size
        # appended an empty "" chunk that was then fed to the summarizer.
        if current_chunk and current_length + tokens > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sentence)
        current_length += tokens
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    # Empty input produces no chunks, so the summarizer is never invoked.
    summaries = [
        summarizer(chunk, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
        for chunk in chunks
    ]
    return " ".join(summaries)