File size: 5,711 Bytes
9965499
c30e46a
 
 
 
96208dc
c30e46a
 
 
08c0325
96208dc
 
bc40121
 
 
 
c30e46a
 
 
 
 
96208dc
 
 
c30e46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc40121
 
c30e46a
bc40121
c30e46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaba1ed
f400521
c30e46a
 
 
08c0325
3d539ca
96208dc
bc40121
 
 
c30e46a
 
 
 
 
 
86771dc
 
c30e46a
bc40121
c30e46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
MedGenesis – dual-LLM orchestrator
----------------------------------
β€’ Accepts llm = "openai" | "gemini"   (falls back to OpenAI)
β€’ Returns one unified dict the UI can rely on.
"""
from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List, Tuple

from mcp.arxiv            import fetch_arxiv
from mcp.pubmed           import fetch_pubmed
from mcp.ncbi             import search_gene, get_mesh_definition
from mcp.mygene           import fetch_gene_info
from mcp.ensembl          import fetch_ensembl
from mcp.opentargets      import fetch_ot
from mcp.umls             import lookup_umls
from mcp.openfda          import fetch_drug_safety
from mcp.disgenet         import disease_to_genes
from mcp.clinicaltrials   import search_trials
from mcp.cbio             import fetch_cbio
from mcp.openai_utils     import ai_summarize, ai_qa
from mcp.gemini           import gemini_summarize, gemini_qa

log = logging.getLogger(__name__)
_DEF = "openai"                                # default engine


# ─────────────────────────────────── helpers ───────────────────────────────────
def _llm_router(engine: str = _DEF) -> Tuple:
    """Resolve an engine name to its ``(summarize_fn, qa_fn, engine_id)`` triple.

    Any value other than ``"gemini"`` (case-insensitive) falls back to the
    OpenAI pair, matching the module's documented default behaviour.
    """
    if engine.lower() != "gemini":
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"

async def _gather_safely(*aws, as_list: bool = True):
    """await gather() that converts Exception β†’ RuntimeError placeholder"""
    out = await asyncio.gather(*aws, return_exceptions=True)
    if as_list:
        # filter exceptions – keep structure but drop failures
        return [x for x in out if not isinstance(x, Exception)]
    return out

async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """Fan out five gene-centric lookups per key and regroup results by source.

    For every key the same five calls are scheduled in a fixed order
    (NCBI, MeSH, MyGene, Ensembl, Open Targets), so result slot ``i % 5``
    identifies the originating service.

    Parameters
    ----------
    keys : seed keywords/symbols to enrich (empty list → empty buckets).

    Returns
    -------
    dict with keys ``ncbi``/``mesh``/``mygene``/``ensembl``/``ot_assoc``,
    each a list of successful, non-empty results.
    """
    jobs = []
    for k in keys:
        jobs += [
            search_gene(k),                    # basic gene info
            get_mesh_definition(k),            # MeSH definitions
            fetch_gene_info(k),                # MyGene
            fetch_ensembl(k),                  # Ensembl x-refs
            fetch_ot(k),                       # Open Targets associations
        ]
    # as_list=False keeps positional alignment; exceptions stay in their slots.
    res = await _gather_safely(*jobs, as_list=False)

    # BUG FIX: Exception instances are truthy, so the previous `and r` check
    # let failed lookups leak into the result lists — filter them explicitly.
    def _slot(idx: int) -> List[Any]:
        return [
            r for i, r in enumerate(res)
            if i % 5 == idx and r and not isinstance(r, Exception)
        ]

    return {
        "ncbi"     : _slot(0),
        "mesh"     : _slot(1),
        "mygene"   : _slot(2),
        "ensembl"  : _slot(3),
        "ot_assoc" : _slot(4),
    }


# ───────────────────────────────── orchestrator ────────────────────────────────
async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
    """Main entry – returns one unified result dict for the Streamlit UI.

    Parameters
    ----------
    query : free-text search string.
    llm   : ``"openai"`` (default) or ``"gemini"`` — summarisation engine.

    Returns
    -------
    dict with papers, UMLS/FDA enrichments, gene bundles, clinical trials,
    variants, the AI summary and the engine actually used.
    """
    # 1  Literature – arXiv + PubMed in parallel; failed fetchers are dropped.
    arxiv_task  = asyncio.create_task(fetch_arxiv(query))
    pubmed_task = asyncio.create_task(fetch_pubmed(query))
    papers_raw  = await _gather_safely(arxiv_task, pubmed_task)
    papers      = list(itertools.chain.from_iterable(papers_raw))[:30]   # keep ≤30

    # 2  Keyword extraction (very light – only from abstracts).
    #    BUG FIX: the old `list(kws)[:10]` truncated an unordered set, making
    #    the seed keywords nondeterministic per run — sort before slicing.
    kws = {
        w
        for p in papers
        for w in p.get("summary", "")[:500].split()
        if w.isalpha()
    }
    kws = sorted(kws)[:10]                         # coarse, fast -> 10 seeds

    # 3  Bio-enrichment fan-out
    umls_f       = [_safe_task(lookup_umls, k) for k in kws]
    fda_f        = [_safe_task(fetch_drug_safety, k) for k in kws]
    gene_bundle  = asyncio.create_task(_gene_enrichment(kws))
    trials_task  = asyncio.create_task(search_trials(query, max_studies=20))
    cbio_task    = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))

    # BUG FIX: trials/cbio/gene lookups were previously unprotected, so one
    # failing source crashed the whole search — collect exceptions instead.
    umls, fda, gene_dat, trials, variants = await asyncio.gather(
        _gather_safely(*umls_f),
        _gather_safely(*fda_f),
        gene_bundle,
        trials_task,
        cbio_task,
        return_exceptions=True,
    )
    # Normalise any failed source to an empty result.
    if isinstance(umls, Exception):
        umls = []
    if isinstance(fda, Exception):
        fda = []
    if not isinstance(gene_dat, dict):
        gene_dat = {"ncbi": [], "mesh": [], "mygene": [], "ensembl": [], "ot_assoc": []}
    if isinstance(trials, Exception):
        trials = []
    if isinstance(variants, Exception):
        variants = []

    # 4  LLM summary (abstracts only, capped to keep the prompt bounded)
    summarise_fn, _, engine = _llm_router(llm)
    summary = await summarise_fn(" ".join(p.get("summary", "") for p in papers)[:12000])

    return {
        "papers"          : papers,
        "umls"            : umls,
        "drug_safety"     : fda,
        "ai_summary"      : summary,
        "llm_used"        : engine,
        "genes"           : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
        "mesh_defs"       : gene_dat["mesh"],
        "gene_disease"    : gene_dat["ot_assoc"],
        "clinical_trials" : trials,
        "variants"        : variants or [],
    }

# ─────────────────────────────── follow-up QA ─────────────────────────────────
async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
    """Answer a follow-up question with the selected LLM engine.

    Builds a single Q/Context/A prompt and delegates to the QA function
    chosen by :func:`_llm_router`; the answer comes back wrapped in a dict.
    """
    qa_fn = _llm_router(llm)[1]
    prompt = f"Q: {question}\nContext: {context}\nA:"
    return {"answer": await qa_fn(prompt)}


# ─────────────────────────── internal util  ───────────────────────────────────
def _safe_task(fn, *args):
    """Helper to wrap callable β†’ Task returning RuntimeError on exception."""
    async def _wrapper():
        try:
            return await fn(*args)
        except Exception as exc:
            log.warning("background task %s failed: %s", fn.__name__, exc)
            return RuntimeError(str(exc))
    return asyncio.create_task(_wrapper())