import requests
import nltk

# sent_tokenize needs the punkt models; newer NLTK releases ship them as
# "punkt_tab". nltk.download() does not raise on a missing package, so
# fetching both quietly covers old and new versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
from nltk.tokenize import sent_tokenize
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import pipeline
from sentence_transformers import SentenceTransformer
import os
import faiss
import numpy as np
import json

from config import (
    PUBMED_EMAIL,
    MAX_PUBMED_RESULTS,
    DEFAULT_SUMMARIZATION_CHUNK,
    VECTORDB_PATH,
    EMBEDDING_MODEL_NAME
)

###############################################################################
#                          SUMMARIZATION & EMBEDDINGS                         #
###############################################################################

summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    tokenizer="facebook/bart-large-cnn",
)

embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
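
# Note: both models load at import time (and may download weights on first
# use), so importing this module can be slow on a cold start.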

###############################################################################
#                            PUBMED UTIL FUNCTIONS                            #
###############################################################################

def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
    """
    Search PubMed for PMIDs matching a query. Returns a list of PMIDs.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL
    }
    resp = requests.get(url, params=params, timeout=10)
    resp.raise_for_status()
    data = resp.json()
    return data.get("esearchresult", {}).get("idlist", [])
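
# Usage sketch (placeholder query and PMIDs, not real results):
#   search_pubmed("metformin AND cardiovascular")  ->  ["12345678", ...]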

def fetch_abstract(pmid):
    """
    Fetches an abstract for a single PMID via EFetch.
    """
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    params = {
        "db": "pubmed",
        "id": pmid,
        "retmode": "text",
        "rettype": "abstract",
        "tool": "AdvancedMedicalAI",
        "email": PUBMED_EMAIL
    }
    resp = requests.get(url, params=params, timeout=10)
    resp.raise_for_status()
    return resp.text.strip()

def fetch_pubmed_abstracts(pmids):
    """
    Parallel fetch for multiple PMIDs. Returns dict {pmid: text}.
    """
    results = {}
    if not pmids:
        return results  # ThreadPoolExecutor rejects max_workers=0
    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
        future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
        for future in as_completed(future_to_pmid):
            pmid = future_to_pmid[future]
            try:
                results[pmid] = future.result()
            except Exception as e:
                results[pmid] = f"Error fetching PMID {pmid}: {str(e)}"
    return results
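
# Note: NCBI E-utilities throttles clients to roughly 3 requests/second
# without an API key, so 5 parallel fetches can be rate-limited; failures
# surface per PMID via the error string above rather than aborting the batch.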

###############################################################################
#                           SUMMARIZE & CHUNK TEXT                            #
###############################################################################

def chunk_and_summarize(raw_text, chunk_size=DEFAULT_SUMMARIZATION_CHUNK):
    """
    Splits large text into chunks by sentences, then summarizes each chunk, merging results.
    """
    sentences = sent_tokenize(raw_text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sent in sentences:
        # Approximate length by whitespace-separated words, not model tokens.
        token_count = len(sent.split())
        if current_chunk and current_length + token_count > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(sent)
        current_length += token_count

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    summary_list = []
    for c in chunks:
        summ = summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
        summary_list.append(summ)
    return " ".join(summary_list)

###############################################################################
#                    SIMPLE VECTOR STORE (FAISS) FOR RAG                      #
###############################################################################

def create_or_load_faiss_index():
    """
    Creates a new FAISS index or loads from disk if it exists.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")

    os.makedirs(VECTORDB_PATH, exist_ok=True)

    if os.path.exists(index_path) and os.path.exists(meta_path):
        # Load existing index
        index = faiss.read_index(index_path)
        with open(meta_path, "r") as f:
            meta_data = json.load(f)
        return index, meta_data
    else:
        # Create new index
        index = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
        meta_data = {}
        return index, meta_data
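
# IndexFlatL2 performs exact brute-force L2 search and needs no training,
# which is adequate for the small per-session corpora handled here.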

def save_faiss_index(index, meta_data):
    """
    Saves the FAISS index and metadata to disk.
    """
    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")

    faiss.write_index(index, index_path)
    with open(meta_path, "w") as f:
        json.dump(meta_data, f)

def upsert_documents(docs):
    """
    Takes in a dict of {pmid: text}, embeds and upserts them into the FAISS index.
    Each doc is stored in 'meta_data' with pmid as key.
    """
    if not docs:
        return

    index, meta_data = create_or_load_faiss_index()

    texts = list(docs.values())
    pmids = list(docs.keys())

    # FAISS requires float32; SentenceTransformer usually returns it already,
    # but convert explicitly to be safe.
    embeddings = np.asarray(embed_model.encode(texts, convert_to_numpy=True), dtype="float32")
    index.add(embeddings)

    # Maintain a simple meta_data: { int_id: { 'pmid': X, 'text': Y } }
    # Where int_id is the row in the index
    start_id = len(meta_data)
    for i, pmid in enumerate(pmids):
        meta_data[str(start_id + i)] = {"pmid": pmid, "text": texts[i]}

    save_faiss_index(index, meta_data)
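
# Note: meta_data keys mirror FAISS row ids, so documents can only be
# appended; deleting or updating one would require rebuilding the index.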

def semantic_search(query, top_k=3):
    """
    Embeds 'query' and searches the FAISS index for top_k similar docs.
    Returns a list of dict with 'pmid' and 'text'.
    """
    index, meta_data = create_or_load_faiss_index()

    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, top_k)

    results = []
    for dist, idx_list in zip(distances, indices):
        for d, i in zip(dist, idx_list):
            if i == -1:
                continue  # FAISS pads with -1 when fewer than top_k vectors exist
            # i is the row in the index; look it up in meta_data
            doc_info = meta_data[str(i)]
            results.append({"pmid": doc_info["pmid"], "text": doc_info["text"], "score": float(d)})
    # Sort by ascending distance => best match first
    results.sort(key=lambda x: x["score"])
    return results
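

###############################################################################
#                         ILLUSTRATIVE USAGE (SKETCH)                         #
###############################################################################

# A minimal end-to-end sketch, assuming network access, a valid PUBMED_EMAIL,
# and that the models above have finished downloading. The query strings are
# placeholders, not part of the original module.
if __name__ == "__main__":
    pmids = search_pubmed("metformin cardiovascular outcomes", max_results=3)
    abstracts = fetch_pubmed_abstracts(pmids)
    summaries = {pmid: chunk_and_summarize(text) for pmid, text in abstracts.items()}
    upsert_documents(summaries)
    for hit in semantic_search("does metformin reduce cardiovascular risk?", top_k=2):
        print(hit["pmid"], round(hit["score"], 3))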