import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import faiss
import nltk
import numpy as np
import requests
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from config import (
    PUBMED_EMAIL,
    MAX_PUBMED_RESULTS,
    DEFAULT_SUMMARIZATION_CHUNK,
    VECTORDB_PATH,
    EMBEDDING_MODEL_NAME,
)

# Make sure the Punkt sentence tokenizer is available for sent_tokenize.
nltk.download("punkt", quiet=True)


# Summarization pipeline (BART fine-tuned on CNN/DailyMail).
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    tokenizer="facebook/bart-large-cnn",
)

# Sentence-embedding model used to build and query the FAISS vector store.
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)


def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
    """
    Search PubMed for PMIDs matching a query. Returns a list of PMIDs.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL,
    }
    resp = requests.get(url, params=params, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    return data.get("esearchresult", {}).get("idlist", [])
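
# Abridged shape of the ESearch JSON consumed above (PMIDs are placeholders):
#
#   {"esearchresult": {"idlist": ["12345678", "23456789"], ...}}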


def fetch_abstract(pmid):
    """
    Fetches the abstract for a single PMID via EFetch.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "id": pmid,
        "retmode": "text",
        "rettype": "abstract",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL,
    }
    resp = requests.get(url, params=params, timeout=30)
    resp.raise_for_status()
    return resp.text.strip()


def fetch_pubmed_abstracts(pmids):
    """
    Parallel fetch for multiple PMIDs. Returns dict {pmid: text}.
    """
    results = {}
    if not pmids:
        return results
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
        future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                results[pmid] = future.result()
            except Exception as e:
                results[pmid] = f"Error fetching PMID {pmid}: {str(e)}"
    return results
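
# Illustrative sketch (not executed on import): retrieving abstracts for a
# query with the two helpers above. Requires network access and a valid
# PUBMED_EMAIL in config; the query string is a made-up example.
#
#   pmids = search_pubmed("sglt2 inhibitors heart failure", max_results=5)
#   abstracts = fetch_pubmed_abstracts(pmids)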


def chunk_and_summarize(raw_text, chunk_size=DEFAULT_SUMMARIZATION_CHUNK):
    """
    Splits large text into sentence-based chunks, summarizes each chunk,
    and merges the partial summaries.
    """
    sentences = sent_tokenize(raw_text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sent in sentences:
        token_count = len(sent.split())
        # Close the current chunk before the word budget is exceeded; keep
        # chunk_size comfortably below BART's ~1024-token input limit.
        if current_chunk and current_length + token_count > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += token_count

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    summary_list = []
    for c in chunks:
        summ = summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
        summary_list.append(summ)
    return " ".join(summary_list)


def create_or_load_faiss_index():
    """
    Creates a new FAISS index, or loads it from disk if one already exists.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")

    os.makedirs(VECTORDB_PATH, exist_ok=True)

    if os.path.exists(index_path) and os.path.exists(meta_path):
        # Load the persisted index and its row-id -> document metadata map.
        index = faiss.read_index(index_path)
        with open(meta_path, "r") as f:
            meta_data = json.load(f)
        return index, meta_data
    else:
        # Fresh L2 index sized to the embedding model's output dimension.
        index = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
        meta_data = {}
        return index, meta_data


def save_faiss_index(index, meta_data):
    """
    Saves the FAISS index and metadata to disk.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")

    faiss.write_index(index, index_path)
    with open(meta_path, "w") as f:
        json.dump(meta_data, f)


def upsert_documents(docs):
    """
    Takes a dict of {pmid: text}, embeds the texts, and adds them to the FAISS
    index. Each doc is stored in 'meta_data' keyed by its FAISS row id.
    """
    if not docs:
        return

    index, meta_data = create_or_load_faiss_index()

    texts = list(docs.values())
    pmids = list(docs.keys())

    # FAISS expects float32 vectors; SentenceTransformer already returns
    # float32, but the explicit cast keeps this robust.
    embeddings = embed_model.encode(texts, convert_to_numpy=True).astype(np.float32)
    index.add(embeddings)

    # Vectors are appended to the index, so their row ids continue from the
    # current metadata size.
    start_id = len(meta_data)
    for i, pmid in enumerate(pmids):
        meta_data[str(start_id + i)] = {"pmid": pmid, "text": texts[i]}

    save_faiss_index(index, meta_data)
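
# Example of the resulting faiss_meta.json layout (values shortened; PMIDs are
# placeholders), keyed by FAISS row id as written above:
#
#   {"0": {"pmid": "12345678", "text": "..."},
#    "1": {"pmid": "23456789", "text": "..."}}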


def semantic_search(query, top_k=3):
    """
    Embeds 'query' and searches the FAISS index for the top_k most similar
    docs. Returns a list of dicts with 'pmid', 'text', and 'score'.
    """
    index, meta_data = create_or_load_faiss_index()

    query_embedding = embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
    distances, indices = index.search(query_embedding, top_k)

    results = []
    for dist, idx_list in zip(distances, indices):
        for d, i in zip(dist, idx_list):
            # FAISS pads results with -1 when the index holds fewer than
            # top_k vectors; skip those slots.
            if i < 0:
                continue
            doc_info = meta_data.get(str(i))
            if doc_info is None:
                continue
            results.append({"pmid": doc_info["pmid"], "text": doc_info["text"], "score": float(d)})

    # Lower L2 distance means a closer match.
    results.sort(key=lambda x: x["score"])
    return results
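

# Minimal end-to-end sketch, assuming this module is run as a script with
# network access and the models above available. The query string and
# max_results value are illustrative only.
if __name__ == "__main__":
    example_query = "metformin cognitive decline"  # hypothetical example query
    pmids = search_pubmed(example_query, max_results=3)
    abstracts = fetch_pubmed_abstracts(pmids)
    upsert_documents(abstracts)
    for hit in semantic_search(example_query, top_k=3):
        print(hit["pmid"], "->", chunk_and_summarize(hit["text"])[:200])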