File size: 6,205 Bytes
9958236
0bd4f6b
b15fc81
9958236
b15fc81
 
 
0bd4f6b
9958236
0bd4f6b
 
0fb7617
9958236
 
 
 
 
 
 
 
 
 
 
b7db50c
9958236
 
0bd4f6b
 
9958236
0bd4f6b
 
0fb7617
 
0bd4f6b
 
 
 
 
9958236
0fb7617
 
 
9958236
0fb7617
b15fc81
 
 
9958236
b15fc81
9958236
b15fc81
9958236
 
 
b15fc81
0fb7617
 
 
 
 
 
 
 
 
 
b15fc81
 
 
 
 
 
 
 
0fb7617
 
 
 
 
b15fc81
 
 
 
 
 
 
 
 
9958236
0fb7617
 
 
 
 
 
 
 
 
 
b15fc81
0fb7617
 
 
9958236
 
0fb7617
9958236
0fb7617
 
 
 
9958236
 
 
0fb7617
9958236
 
0fb7617
b15fc81
 
b7db50c
9958236
 
 
 
 
0fb7617
9958236
 
 
 
0fb7617
9958236
 
b15fc81
 
9958236
 
 
0fb7617
 
 
b15fc81
0fb7617
 
 
 
 
 
9958236
0fb7617
 
9958236
0fb7617
 
b15fc81
0fb7617
9958236
0bd4f6b
9958236
 
0fb7617
 
 
9958236
 
 
 
 
 
 
2a8cf8d
 
 
9958236
0fb7617
b7db50c
0fb7617
0bd4f6b
9958236
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""
MedGenesis – dual-LLM orchestrator (v5)
---------------------------------------
• No external 'pytrials' dependency.
• Uses direct HTTP for clinical trials.
• Clean async fan-out, dual-LLM support.
"""

from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List, Tuple

from mcp.arxiv           import fetch_arxiv
from mcp.pubmed          import fetch_pubmed
from mcp.ncbi            import search_gene, get_mesh_definition
from mcp.mygene          import fetch_gene_info
from mcp.ensembl         import fetch_ensembl
from mcp.opentargets     import fetch_ot
from mcp.umls            import lookup_umls
from mcp.openfda         import fetch_drug_safety
from mcp.disgenet        import disease_to_genes
from mcp.clinicaltrials  import fetch_clinical_trials
from mcp.cbio            import fetch_cbio
from mcp.openai_utils    import ai_summarize, ai_qa
from mcp.gemini          import gemini_summarize, gemini_qa

# Module-level logger; the async helpers in this module report failed tasks
# through it (e.g. "Task failed: ...").
log = logging.getLogger(__name__)
# Engine name used when callers do not pass one. _llm_router treats any value
# other than "gemini" (case-insensitive) as OpenAI.
_DEFAULT_LLM = "openai"


def _llm_router(engine: str = _DEFAULT_LLM) -> Tuple:
    """
    Map an engine name to its ``(summarize_fn, qa_fn, engine_label)`` triple.

    The comparison is case-insensitive. Only "gemini" selects the Gemini
    helpers; every other value falls back to the OpenAI helpers, and the
    returned label reports which engine was actually chosen.
    """
    wants_gemini = engine.lower() == "gemini"
    if not wants_gemini:
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"


async def _safe_gather(*tasks, return_exceptions: bool = False):
    """
    Await ``tasks`` concurrently without letting one failure abort the rest.

    All awaitables are run through ``asyncio.gather(..., return_exceptions=True)``,
    so a failing task neither cancels its siblings nor hides their results.
    Every failure is logged at WARNING level.

    Args:
        *tasks: Awaitables (coroutines, tasks or futures) to run.
        return_exceptions: When True, failures are kept in place as their
            exception objects, so the output stays index-aligned with
            ``tasks``. When False (the default), failures are dropped and
            only successful results are returned, in their original order.

    Returns:
        A list of results, optionally interleaved with exception objects.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned: List[Any] = []
    for r in results:
        # gather(return_exceptions=True) also returns BaseException
        # subclasses such as asyncio.CancelledError (no longer an Exception
        # subclass since Python 3.8). Treat those as failures too, so they
        # never masquerade as successful results.
        if isinstance(r, BaseException):
            # %r: str(CancelledError()) is empty, so show the type as well.
            log.warning("Task failed: %r", r)
            if return_exceptions:
                cleaned.append(r)
        else:
            cleaned.append(r)
    return cleaned


async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan out the gene-related endpoints for every seed keyword.

    For each keyword in ``keys`` these lookups run concurrently:
      - ``ncbi``     -> search_gene          (NCBI gene lookup)
      - ``mesh``     -> get_mesh_definition  (MeSH definition)
      - ``mygene``   -> fetch_gene_info      (MyGene.info)
      - ``ensembl``  -> fetch_ensembl        (Ensembl cross-refs)
      - ``ot_assoc`` -> fetch_ot             (OpenTargets associations)

    Failed lookups are logged by ``_safe_gather`` and omitted from the output.

    Returns:
        A dict mapping each source name above to the list of its successful
        results, in the same order as ``keys``.
    """
    # Single source of truth for (result key, fetcher). The bucketing stride
    # below is derived from this table rather than a hard-coded 5.
    sources = (
        ("ncbi", search_gene),
        ("mesh", get_mesh_definition),
        ("mygene", fetch_gene_info),
        ("ensembl", fetch_ensembl),
        ("ot_assoc", fetch_ot),
    )
    jobs: List[asyncio.Task] = [
        asyncio.create_task(fetch(k))
        for k in keys
        for _, fetch in sources
    ]
    # return_exceptions=True keeps failures in place, so result i still
    # belongs to jobs[i] and the stride-based bucketing stays aligned.
    results = await _safe_gather(*jobs, return_exceptions=True)

    stride = len(sources)
    return {
        name: [
            res
            for res in results[idx::stride]
            # BaseException (not Exception) so that asyncio.CancelledError,
            # which gather() returns in place, is filtered out as well.
            if not isinstance(res, BaseException)
        ]
        for idx, (name, _) in enumerate(sources)
    }


def _extract_seeds(papers: List[Dict[str, Any]], limit: int = 10) -> List[str]:
    """
    Return up to ``limit`` distinct purely-alphabetic words from paper summaries.

    Only the first 500 characters of each summary are scanned. Words are kept
    in first-seen order, so the same papers always yield the same seeds. (A
    plain ``set`` made the selection depend on per-process string-hash
    randomization.)
    """
    words = (
        w
        for p in papers
        # `or ""` also covers a summary key that is present but None.
        for w in (p.get("summary") or "")[:500].split()
        if w.isalpha()
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(words))[:limit]


def _dedupe_variants(variants: Any) -> List[dict]:
    """
    Drop variants that repeat an earlier one's genomic coordinates.

    The key is (chromosome, startPosition, referenceAllele, variantAllele).
    The first occurrence wins; ``None``/empty input yields ``[]``.
    """
    seen: set = set()
    unique_vars: List[dict] = []
    for v in variants or []:
        key = (
            v.get("chromosome"),
            v.get("startPosition"),
            v.get("referenceAllele"),
            v.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique_vars.append(v)
    return unique_vars


def _or_default(value: Any, default: Any, label: str) -> Any:
    """Return ``value``, or log a warning and return ``default`` if it is an exception."""
    if isinstance(value, BaseException):
        log.warning("%s lookup failed: %r", label, value)
        return default
    return value


async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point. Performs:
      1. Literature fetch (PubMed + arXiv)
      2. Keyword seed extraction
      3. Bio-enrichment (UMLS, OpenFDA, gene services), clinical-trials
         lookup and cBioPortal variants, all fanned out concurrently
      4. Variant de-duplication
      5. AI LLM summary

    A failure in any single source degrades that section to an empty value
    (and is logged) instead of aborting the whole search. An LLM failure is
    reported in ``ai_summary`` as "LLM error: ...", mirroring
    ``answer_ai_question``.

    Returns:
        A unified dict for the UI with the keys papers, umls, drug_safety,
        clinical_trials, variants, genes, mesh_defs, gene_disease,
        ai_summary and llm_used.
    """
    # 1) Literature
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    papers = list(itertools.chain.from_iterable(papers_raw))[:30]

    # 2) Seed keywords (deterministic, first-seen order)
    seeds = _extract_seeds(papers, limit=10)

    # 3) Bio-enrichment fan-out
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_task = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    # return_exceptions=True: previously one failing source (e.g. trials)
    # aborted the whole search while its siblings kept running unobserved.
    umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_task,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    umls_list = _or_default(umls_list, [], "UMLS")
    fda_list = _or_default(fda_list, [], "OpenFDA")
    gene_data = _or_default(
        gene_data,
        {key: [] for key in ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")},
        "gene enrichment",
    )
    trials = _or_default(trials, [], "clinical trials")
    variants = _or_default(variants, [], "cBioPortal")

    # 4) Deduplicate variants by genomic coordinates
    unique_vars = _dedupe_variants(variants)

    # 5) LLM-driven summary
    summarize_fn, _, engine_used = _llm_router(llm)
    combined = " ".join((p.get("summary") or "") for p in papers)
    try:
        ai_summary = await summarize_fn(combined[:12000])
    except Exception as e:
        log.warning("LLM summary failed: %r", e)
        ai_summary = f"LLM error: {e}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }


async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Answer a follow-up question with the QA function of the selected LLM.

    The question and context are packed into a "Q: / Context: / A:" prompt.
    An exception raised by the LLM call is not propagated; its message is
    returned as the answer text instead, prefixed with "LLM error: ".

    Returns:
        ``{"answer": <model answer or error message>}``
    """
    qa_fn = _llm_router(llm)[1]
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        reply = await qa_fn(prompt)
    except Exception as exc:
        reply = f"LLM error: {exc}"
    return dict(answer=reply)