import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

from config import (
    PUBMED_EMAIL,
    MAX_PUBMED_RESULTS,
    DEFAULT_SUMMARIZATION_CHUNK,
    VECTORDB_PATH,
    EMBEDDING_MODEL_NAME,
)

# Make sure the Punkt sentence tokenizer is available for sent_tokenize.
nltk.download("punkt")

###############################################################################
# SUMMARIZATION & EMBEDDINGS #
###############################################################################
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    tokenizer="facebook/bart-large-cnn",
)
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
###############################################################################
# PUBMED UTIL FUNCTIONS #
###############################################################################
def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
    """
    Search PubMed for PMIDs matching a query. Returns a list of PMIDs.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL,
    }
    resp = requests.get(url, params=params)
    resp.raise_for_status()
    data = resp.json()
    return data.get("esearchresult", {}).get("idlist", [])

def fetch_abstract(pmid):
    """
    Fetches an abstract for a single PMID via EFetch.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "id": pmid,
        "retmode": "text",
        "rettype": "abstract",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL,
    }
    resp = requests.get(url, params=params)
    resp.raise_for_status()
    return resp.text.strip()

def fetch_pubmed_abstracts(pmids):
    """
    Parallel fetch for multiple PMIDs. Returns dict {pmid: text}.
    """
    # ThreadPoolExecutor requires max_workers >= 1, so bail out early on an
    # empty PMID list.
    if not pmids:
        return {}
    results = {}
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
        future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                results[pmid] = future.result()
            except Exception as e:
                results[pmid] = f"Error fetching PMID {pmid}: {e}"
    return results
###############################################################################
# SUMMARIZE & CHUNK TEXT #
###############################################################################
def chunk_and_summarize(raw_text, chunk_size=DEFAULT_SUMMARIZATION_CHUNK):
    """
    Splits large text into chunks by sentences, then summarizes each chunk, merging results.
    """
    sentences = sent_tokenize(raw_text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sent in sentences:
        token_count = len(sent.split())
        # Close the current chunk before it exceeds chunk_size; the
        # "and current_chunk" check avoids emitting an empty chunk when a
        # single sentence is longer than chunk_size.
        if current_length + token_count > chunk_size and current_chunk:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += token_count
    if current_chunk:
        chunks.append(" ".join(current_chunk))

    summary_list = []
    for c in chunks:
        summ = summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
        summary_list.append(summ)
    return " ".join(summary_list)

###############################################################################
# SIMPLE VECTOR STORE (FAISS) FOR RAG #
###############################################################################
def create_or_load_faiss_index():
    """
    Creates a new FAISS index or loads from disk if it exists.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
    os.makedirs(VECTORDB_PATH, exist_ok=True)
    if os.path.exists(index_path) and os.path.exists(meta_path):
        # Load existing index and its metadata
        index = faiss.read_index(index_path)
        with open(meta_path, "r") as f:
            meta_data = json.load(f)
        return index, meta_data
    # Create a new, empty L2 index sized to the embedding model
    index = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
    meta_data = {}
    return index, meta_data

def save_faiss_index(index, meta_data):
    """
    Saves the FAISS index and metadata to disk.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
    faiss.write_index(index, index_path)
    with open(meta_path, "w") as f:
        json.dump(meta_data, f)

def upsert_documents(docs):
    """
    Takes a dict of {pmid: text}, embeds the texts, and upserts them into the FAISS index.
    Each doc is stored in 'meta_data' keyed by its row number in the index.
    """
    if not docs:
        return
    index, meta_data = create_or_load_faiss_index()
    texts = list(docs.values())
    pmids = list(docs.keys())
    embeddings = embed_model.encode(texts, convert_to_numpy=True)
    # FAISS expects float32 vectors; the cast is a no-op when the encoder
    # already returns float32.
    index.add(np.asarray(embeddings, dtype=np.float32))
    # Maintain simple metadata: { int_id: { 'pmid': X, 'text': Y } },
    # where int_id is the row number in the index.
    start_id = len(meta_data)
    for i, pmid in enumerate(pmids):
        meta_data[str(start_id + i)] = {"pmid": pmid, "text": texts[i]}
    save_faiss_index(index, meta_data)

def semantic_search(query, top_k=3):
    """
    Embeds 'query' and searches the FAISS index for the top_k most similar docs.
    Returns a list of dicts with 'pmid', 'text', and 'score'.
    """
    index, meta_data = create_or_load_faiss_index()
    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, top_k)
    results = []
    for dist, idx_list in zip(distances, indices):
        for d, i in zip(dist, idx_list):
            # FAISS pads with -1 when the index holds fewer than top_k vectors.
            if i == -1:
                continue
            # i is the row in the index; look it up in meta_data.
            doc_info = meta_data[str(i)]
            results.append({"pmid": doc_info["pmid"], "text": doc_info["text"], "score": float(d)})
    # Sort by ascending distance => best match first
    results.sort(key=lambda x: x["score"])
    return results
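

###############################################################################
# USAGE SKETCH #
###############################################################################
# A minimal end-to-end sketch of the search -> summarize -> FAISS flow, assuming
# config.py provides the constants imported above. The query strings and top_k
# below are illustrative choices only; running this block performs live NCBI
# E-utilities requests and downloads the BART and sentence-transformer models.
if __name__ == "__main__":
    # 1. Find a few PMIDs for an example query.
    pmids = search_pubmed("metformin cardiovascular outcomes", max_results=3)

    # 2. Fetch their abstracts in parallel and summarize each one.
    abstracts = fetch_pubmed_abstracts(pmids)
    summaries = {pmid: chunk_and_summarize(text) for pmid, text in abstracts.items()}

    # 3. Upsert the summaries into the vector store, then query it semantically.
    upsert_documents(summaries)
    for hit in semantic_search("cardiac risk reduction", top_k=3):
        print(hit["pmid"], round(hit["score"], 3), hit["text"][:120])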