#!/usr/bin/env python3
"""
MedGenesis – dual-LLM asynchronous orchestrator
===============================================
• Accepts `llm` argument ("openai" | "gemini"), defaults to "openai".
• Harvests literature (PubMed + arXiv) → extracts keywords.
• Fans out to open APIs for genes, trials, safety, ontology:
    – **MyGene.info** for live gene annotations
    – **ClinicalTrials.gov v2** for recruiting & completed studies
    – UMLS / openFDA / DisGeNET / MeSH (existing helpers)
    – Optional Open Targets & DrugCentral via `multi_enrich` if needed.
• Returns a single JSON-serialisable dict consumed by the Streamlit UI.
"""

from __future__ import annotations

import asyncio
from typing import Any, Dict, List

# ── Literature fetchers ─────────────────────────────────────────────
from mcp.arxiv  import fetch_arxiv
from mcp.pubmed import fetch_pubmed

# ── NLP & legacy enrichers ─────────────────────────────────────────
from mcp.nlp      import extract_keywords
from mcp.umls     import lookup_umls
from mcp.openfda  import fetch_drug_safety
from mcp.ncbi     import search_gene, get_mesh_definition
from mcp.disgenet import disease_to_genes

# ── Modern high‑throughput APIs ────────────────────────────────────
from mcp.mygene   import fetch_gene_info            # MyGene.info
from mcp.ctgov    import search_trials_v2           # ClinicalTrials.gov v2
# from mcp.targets  import fetch_ot_associations    # (optional future use)

# ── LLM utilities ─────────────────────────────────────────────────
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini       import gemini_summarize, gemini_qa

# ------------------------------------------------------------------
# LLM router
# ------------------------------------------------------------------

def _get_llm(llm: str):
    """Return (summarize_fn, qa_fn) based on requested engine."""
    if llm and llm.lower() == "gemini":
        return gemini_summarize, gemini_qa
    return ai_summarize, ai_qa  # default → OpenAI

# ------------------------------------------------------------------
# Helper: batch NCBI / MeSH / DisGeNET enrichment for keyword list
# ------------------------------------------------------------------
async def _enrich_ncbi_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    jobs = [search_gene(k) for k in keys] + \
           [get_mesh_definition(k) for k in keys] + \
           [disease_to_genes(k) for k in keys]

    results = await asyncio.gather(*jobs, return_exceptions=True)

    genes, mesh_defs, disg_links = [], [], []
    n = len(keys)
    for idx, res in enumerate(results):
        if isinstance(res, Exception):
            continue
        bucket = idx // n  # 0 = gene, 1 = mesh, 2 = disg
        if bucket == 0:
            genes.extend(res)
        elif bucket == 1:
            mesh_defs.append(res)
        else:
            disg_links.extend(res)

    return {"genes": genes, "meshes": mesh_defs, "disgenet": disg_links}

# ------------------------------------------------------------------
# Main orchestrator
# ------------------------------------------------------------------
async def orchestrate_search(query: str, *, llm: str = "openai") -> Dict[str, Any]:
    """Master async pipeline – returns dict consumed by UI."""

    # 1) Literature --------------------------------------------------
    arxiv_task  = asyncio.create_task(fetch_arxiv(query))
    pubmed_task = asyncio.create_task(fetch_pubmed(query))
    papers = sum(await asyncio.gather(arxiv_task, pubmed_task), [])

    # 2) Keyword extraction -----------------------------------------
    corpus = " ".join(p.get("summary", "") for p in papers)  # tolerate missing abstracts
    keywords = extract_keywords(corpus)[:8]

    # 3) Fan-out enrichment -----------------------------------------
    umls_tasks  = [lookup_umls(k)       for k in keywords]
    fda_tasks   = [fetch_drug_safety(k) for k in keywords]

    ncbi_task   = asyncio.create_task(_enrich_ncbi_mesh_disg(keywords))
    mygene_task = asyncio.create_task(fetch_gene_info(query))           # top gene hit
    trials_task = asyncio.create_task(search_trials_v2(query, max_n=20))

    umls, fda, ncbi_data, mygene, trials = await asyncio.gather(
        asyncio.gather(*umls_tasks, return_exceptions=True),
        asyncio.gather(*fda_tasks,  return_exceptions=True),
        ncbi_task,
        mygene_task,
        trials_task,
    )

    # 4) LLM summary -------------------------------------------------
    summarize_fn, _ = _get_llm(llm)
    ai_summary = await summarize_fn(corpus)

    # 5) Assemble payload -------------------------------------------
    return {
        "papers"         : papers,
        "umls"           : umls,
        "drug_safety"    : fda,
        "ai_summary"     : ai_summary,
        "llm_used"       : llm.lower(),

        # Gene & variant context
        "genes"          : (ncbi_data["genes"] or []) + ([mygene] if mygene else []),
        "mesh_defs"      : ncbi_data["meshes"],
        "gene_disease"   : ncbi_data["disgenet"],

        # Clinical trials
        "clinical_trials": trials,
    }

# ------------------------------------------------------------------
async def answer_ai_question(question: str, *, context: str, llm: str = "openai") -> Dict[str, str]:
    """One‑shot follow‑up QA using selected engine."""
    _, qa_fn = _get_llm(llm)
    return {"answer": await qa_fn(question, context)}