File size: 6,307 Bytes
9958236
 
 
0bd4f6b
9958236
 
 
 
 
 
0bd4f6b
9958236
0bd4f6b
 
9958236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bd4f6b
 
9958236
0bd4f6b
 
9958236
 
0bd4f6b
 
 
 
 
9958236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bd4f6b
9958236
 
 
 
 
 
 
 
 
 
2a8cf8d
 
 
9958236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bd4f6b
9958236
 
 
 
 
08c0325
3d539ca
9958236
 
 
 
 
86771dc
 
9958236
 
 
 
 
0bd4f6b
9958236
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
# mcp/orchestrator.py

"""
MedGenesis – dual-LLM orchestrator (v4)
---------------------------------------
• Accepts llm="openai" | "gemini"  (defaults to OpenAI)
• Safely runs all data-source calls in parallel
• Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
• Returns one dict that the Streamlit UI can rely on
"""

from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List

from mcp.arxiv           import fetch_arxiv
from mcp.pubmed          import fetch_pubmed
from mcp.ncbi            import search_gene, get_mesh_definition
from mcp.mygene          import fetch_gene_info
from mcp.ensembl         import fetch_ensembl
from mcp.opentargets     import fetch_ot
from mcp.umls            import lookup_umls
from mcp.openfda         import fetch_drug_safety
from mcp.disgenet        import disease_to_genes
from mcp.clinicaltrials  import fetch_clinical_trials
from mcp.cbio            import fetch_cbio_variants
from mcp.openai_utils    import ai_summarize, ai_qa
from mcp.gemini          import gemini_summarize, gemini_qa

log = logging.getLogger(__name__)
_DEFAULT_LLM = "openai"


def _llm_router(engine: str = _DEFAULT_LLM):
    """Resolve an engine name to its (summarize_fn, qa_fn, engine_name) triple.

    Any value other than "gemini" (case-insensitive) falls back to the
    OpenAI pair, matching the module default.
    """
    routes = {
        "gemini": (gemini_summarize, gemini_qa, "gemini"),
        "openai": (ai_summarize, ai_qa, "openai"),
    }
    return routes.get(engine.lower(), routes["openai"])


async def _safe_gather(*tasks, return_exceptions: bool = False):
    """
    Wrapper around asyncio.gather that logs failures
    and optionally returns exceptions as results.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned = []
    for idx, res in enumerate(results):
        if isinstance(res, Exception):
            log.warning("Task %d failed: %s", idx, res)
            if return_exceptions:
                cleaned.append(res)
        else:
            cleaned.append(res)
    return cleaned


async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point for MedGenesis UI.

    Runs literature search, keyword-seeded bio-enrichment and an LLM summary
    for *query*, with the engine chosen by *llm* ("openai" | "gemini").

    Returns a dict with:
      - papers, umls, drug_safety, clinical_trials, variants
      - genes, mesh_defs, gene_disease
      - ai_summary, llm_used

    Fix: removed a dead intermediate — the previous version built a
    deduplicated gene-symbol set after enrichment but never used it; the
    returned "genes" value has always been the raw per-source record lists.
    """
    # 1) Literature (PubMed + arXiv in parallel). A failed fetcher is
    #    logged and dropped by _safe_gather, so papers may come from one
    #    source only. Capped at 30 combined papers.
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t  = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    papers = list(itertools.chain.from_iterable(papers_raw))[:30]

    # 2) Keyword seeds from abstracts: first 500 chars of each summary,
    #    whitespace-split, alphabetic tokens only, capped at 10.
    seeds = {
        w.strip()
        for p in papers
        for w in p.get("summary", "")[:500].split()
        if w.isalpha()
    }
    seeds = list(seeds)[:10]

    # 3) Fan-out all bio-enrichment tasks safely.
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks  = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
    trials_t      = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    # cBioPortal is queried with only the first seed; with no seeds we
    # await a no-op that resolves to an empty list.
    cbio_t        = asyncio.create_task(
        fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_enrich_t,
        trials_t,
        cbio_t,
    )

    # 4) Dedupe variants by (chrom, pos, ref, alt). Missing keys hash as
    #    None and still take part in deduplication. Assumes each variant is
    #    a dict — TODO confirm against fetch_cbio_variants.
    seen = set()
    unique_vars: List[dict] = []
    for var in variants or []:
        key = (var.get("chromosome"), var.get("startPosition"),
               var.get("referenceAllele"), var.get("variantAllele"))
        if key not in seen:
            seen.add(key)
            unique_vars.append(var)

    # 5) LLM summary over concatenated abstracts, truncated to 12k chars
    #    to keep the prompt within model context limits.
    summarize_fn, _, engine_used = _llm_router(llm)
    long_text = " ".join(p.get("summary", "") for p in papers)
    ai_summary = await summarize_fn(long_text[:12000])

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, Exception)],
        "drug_safety": list(itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        # Raw per-source gene records; OpenTargets associations are exposed
        # separately under "gene_disease".
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }


async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan out five gene-related lookups per seed key and regroup by source.

    Per key: NCBI gene search, MeSH definition, MyGene.info record,
    Ensembl xrefs and OpenTargets associations.

    Returns a dict of per-source lists; failed lookups are dropped.
    """
    # Fixed lookup order means position modulo 5 identifies the source
    # when the flat result list is regrouped below.
    lookups = (search_gene, get_mesh_definition, fetch_gene_info, fetch_ensembl, fetch_ot)
    jobs = [asyncio.create_task(fn(k)) for k in keys for fn in lookups]

    results = await _safe_gather(*jobs, return_exceptions=True)

    source_names = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")
    buckets: Dict[str, Any] = {name: [] for name in source_names}
    for pos, res in enumerate(results):
        if not isinstance(res, Exception):
            buckets[source_names[pos % 5]].append(res)
    return buckets


async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Answer a follow-up question with the selected LLM engine.

    Backend errors are caught and surfaced as the answer text instead of
    being raised to the caller.
    """
    qa_fn = _llm_router(llm)[1]
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        reply = await qa_fn(prompt)
    except Exception as e:
        reply = f"LLM error: {e}"
    return {"answer": reply}