File size: 6,307 Bytes
9958236 0bd4f6b 9958236 0bd4f6b 9958236 0bd4f6b 9958236 0bd4f6b 9958236 0bd4f6b 9958236 0bd4f6b 9958236 0bd4f6b 9958236 2a8cf8d 9958236 0bd4f6b 9958236 08c0325 3d539ca 9958236 86771dc 9958236 0bd4f6b 9958236 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
#!/usr/bin/env python3
# mcp/orchestrator.py
"""
MedGenesis – dual-LLM orchestrator (v4)
---------------------------------------
• Accepts llm="openai" | "gemini" (defaults to OpenAI)
• Safely runs all data-source calls in parallel
• Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
• Returns one dict that the Streamlit UI can rely on
"""
from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import fetch_clinical_trials
from mcp.cbio import fetch_cbio_variants
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Engine name used as the default for every `llm` / `engine` parameter below.
_DEFAULT_LLM = "openai"
def _llm_router(engine: str = _DEFAULT_LLM):
    """Returns (summarize_fn, qa_fn, engine_name).

    The engine name is matched case-insensitively and ignoring surrounding
    whitespace. ``"gemini"`` selects the Gemini helpers; every other value,
    including ``None`` or an empty string, selects the OpenAI helpers.
    """
    # Normalise before comparing so None / "" / " Gemini " do not crash or
    # silently pick the wrong engine.
    name = (engine or _DEFAULT_LLM).strip().lower()
    if name == "gemini":
        return gemini_summarize, gemini_qa, "gemini"
    return ai_summarize, ai_qa, "openai"
async def _safe_gather(*tasks, return_exceptions: bool = False):
    """
    Wrapper around asyncio.gather that logs failures
    and optionally returns exceptions as results.

    Args:
        *tasks: Awaitables to run concurrently.
        return_exceptions: When True, a failed task's exception is kept in
            the output at the task's own position, so the output lines up
            one-to-one with ``tasks``. When False (default), failed tasks
            are dropped and the output may be shorter than ``tasks``.

    Returns:
        A list of results in task order.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned = []
    for idx, res in enumerate(results):
        # BaseException, not Exception: a cancelled child is reported as
        # asyncio.CancelledError, which does not derive from Exception and
        # would otherwise be passed on as if it were a normal result.
        if isinstance(res, BaseException):
            log.warning("Task %d failed: %s", idx, res)
            if return_exceptions:
                cleaned.append(res)
        else:
            cleaned.append(res)
    return cleaned
async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point for MedGenesis UI.

    Fetches literature, derives up to ten keyword seeds from the paper
    summaries, fans out the enrichment look-ups for those seeds and asks
    the selected LLM to summarise the collected abstracts.

    Args:
        query: Search string passed to fetch_pubmed, fetch_arxiv and
            fetch_clinical_trials.
        llm: Engine name handed to ``_llm_router``.

    Returns a dict with:
      - papers, umls, drug_safety, clinical_trials, variants
      - genes, mesh_defs, gene_disease
      - ai_summary, llm_used

    A failing data source degrades to an empty value for its key instead of
    aborting the search. An LLM failure is reported in ``ai_summary`` as an
    ``"LLM error: ..."`` string, the same way ``answer_ai_question`` reports
    it.
    """
    # 1) Literature (PubMed + arXiv in parallel)
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    # Only flatten real sequences; anything else (None, a stray exception
    # object) would break chain.from_iterable.
    papers = list(
        itertools.chain.from_iterable(
            batch for batch in papers_raw if isinstance(batch, (list, tuple))
        )
    )[:30]

    # 2) Keyword seeds from abstracts (first 500 chars, split on whitespace).
    # dict.fromkeys de-duplicates while keeping first-seen order; going
    # through a set made the selected seeds (and seeds[0], used for the
    # cBioPortal lookup) differ from run to run.
    seeds = list(
        dict.fromkeys(
            w
            for p in papers
            for w in (p.get("summary") or "")[:500].split()
            if w.isalpha()
        )
    )[:10]

    # 3) Fan-out all bio-enrichment tasks safely
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )
    outcomes = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_enrich_t,
        trials_t,
        cbio_t,
        # Without this, one failing source (e.g. ClinicalTrials.gov) would
        # raise here and throw away everything already fetched.
        return_exceptions=True,
    )
    labels = ("umls", "drug_safety", "gene_enrichment", "clinical_trials", "variants")
    fallbacks = ([], [], {}, [], [])
    settled = []
    for label, outcome, fallback in zip(labels, outcomes, fallbacks):
        if isinstance(outcome, BaseException):
            log.warning("%s lookup failed: %s", label, outcome)
            outcome = fallback
        settled.append(outcome)
    umls_list, fda_list, gene_data, trials, variants = settled

    # 4) Dedupe variants by (chrom, pos, ref, alt) if returned as dicts
    seen = set()
    unique_vars: List[Any] = []
    for var in variants or []:
        if not isinstance(var, dict):
            # No fields to build a dedupe key from; keep the record as-is.
            unique_vars.append(var)
            continue
        key = (
            var.get("chromosome"),
            var.get("startPosition"),
            var.get("referenceAllele"),
            var.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique_vars.append(var)

    # 5) LLM summary
    summarize_fn, _, engine_used = _llm_router(llm)
    long_text = " ".join(p.get("summary") or "" for p in papers)
    try:
        ai_summary = await summarize_fn(long_text[:12000])
    except Exception as e:
        # Keep the fetched data usable even when the LLM call fails.
        log.warning("LLM summary failed: %s", e)
        ai_summary = f"LLM error: {e}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": (
            gene_data.get("ncbi", [])
            + gene_data.get("ensembl", [])
            + gene_data.get("mygene", [])
        ),
        "mesh_defs": gene_data.get("mesh", []),
        "gene_disease": gene_data.get("ot_assoc", []),
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Run the five gene-related look-ups for every seed key concurrently.

    For each key, in this order:
      - search_gene          -> "ncbi"
      - get_mesh_definition  -> "mesh"
      - fetch_gene_info      -> "mygene"
      - fetch_ensembl        -> "ensembl"
      - fetch_ot             -> "ot_assoc"

    Returns a dict mapping each label above to the list of results from
    that look-up across all keys; results that are Exception instances
    are left out.
    """
    sources = (
        ("ncbi", search_gene),
        ("mesh", get_mesh_definition),
        ("mygene", fetch_gene_info),
        ("ensembl", fetch_ensembl),
        ("ot_assoc", fetch_ot),
    )
    pending = [
        asyncio.create_task(fetch(key))
        for key in keys
        for _, fetch in sources
    ]
    outcomes = await _safe_gather(*pending, return_exceptions=True)
    width = len(sources)
    # Tasks were queued key by key, so every `width`-th outcome starting at
    # a source's slot belongs to that source.
    return {
        label: [r for r in outcomes[slot::width] if not isinstance(r, Exception)]
        for slot, (label, _) in enumerate(sources)
    }
async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Follow-up QA: send one question plus its context to the selected LLM.

    The QA function chosen by ``_llm_router(llm)`` receives a single prompt
    string. Returns ``{"answer": ...}``; if the QA call raises, the answer
    is an ``"LLM error: ..."`` message instead of the exception propagating.
    """
    ask = _llm_router(llm)[1]
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        return {"answer": await ask(prompt)}
    except Exception as exc:
        return {"answer": f"LLM error: {exc}"}
|