MCP_Res / mcp /orchestrator.py
mgbam's picture
Update mcp/orchestrator.py
b7db50c verified
#!/usr/bin/env python3
"""
MedGenesis – dual-LLM orchestrator (v5)
---------------------------------------
• No external 'pytrials' dependency.
• Uses direct HTTP for clinical trials.
• Clean async fan-out, dual-LLM support.
"""
from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List, Tuple
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import fetch_clinical_trials
from mcp.cbio import fetch_cbio
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Engine name used as the default for the ``engine`` / ``llm`` parameters below.
_DEFAULT_LLM = "openai"
def _llm_router(engine: str = _DEFAULT_LLM) -> Tuple:
    """Return the ``(summarize_fn, qa_fn, engine_name)`` triple for *engine*.

    The comparison is case-insensitive. Any value other than ``"gemini"``
    selects the OpenAI pair.
    """
    wants_gemini = engine.lower() == "gemini"
    if not wants_gemini:
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"
async def _safe_gather(*tasks, return_exceptions: bool = False):
    """
    Await several awaitables concurrently without letting one failure abort the rest.

    Every failure is logged at WARNING level.

    Args:
        *tasks: Coroutines, tasks or futures to await.
        return_exceptions: When True, a failed task's exception object is kept
            in the output at that task's position. When False (the default),
            failed tasks are dropped, so the output may be shorter than the
            input.

    Returns:
        A list of results in submission order.
    """
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned: List[Any] = []
    for outcome in outcomes:
        # Test against BaseException, not Exception: a cancelled child is
        # reported as asyncio.CancelledError, which derives only from
        # BaseException on Python 3.8+ and would otherwise pass for a result.
        failed = isinstance(outcome, BaseException)
        if failed:
            log.warning("Task failed: %s", outcome)
        if return_exceptions or not failed:
            cleaned.append(outcome)
    return cleaned
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan out the gene-related endpoints for each seed keyword.

    For every key, five lookups are started concurrently, in this order:
      - NCBI gene lookup
      - MeSH definition
      - MyGene.info
      - Ensembl cross-refs
      - OpenTargets associations

    Args:
        keys: Seed keywords to look up.

    Returns:
        Dict with keys ``ncbi``, ``mesh``, ``mygene``, ``ensembl`` and
        ``ot_assoc``. Each maps to the list of successful results from that
        service. Failed or cancelled lookups are omitted, so the lists are
        not guaranteed to align with ``keys``.
    """
    jobs: List[asyncio.Task] = []
    for k in keys:
        jobs.extend([
            asyncio.create_task(search_gene(k)),
            asyncio.create_task(get_mesh_definition(k)),
            asyncio.create_task(fetch_gene_info(k)),
            asyncio.create_task(fetch_ensembl(k)),
            asyncio.create_task(fetch_ot(k)),
        ])

    # return_exceptions=True keeps one slot per job, so the 5-per-key layout
    # survives failures and stride slicing below stays valid.
    results = await _safe_gather(*jobs, return_exceptions=True)

    def bucket(idx: int) -> List[Any]:
        # Every 5th slot starting at idx belongs to the same service.
        # BaseException also filters asyncio.CancelledError (Python 3.8+).
        return [res for res in results[idx::5] if not isinstance(res, BaseException)]

    return {
        "ncbi": bucket(0),
        "mesh": bucket(1),
        "mygene": bucket(2),
        "ensembl": bucket(3),
        "ot_assoc": bucket(4),
    }
async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point. Performs:
      1. Literature fetch (PubMed + arXiv)
      2. Keyword seed extraction
      3. Bio-enrichment (UMLS, OpenFDA, gene services)
      4. Clinical trials lookup
      5. cBioPortal variants
      6. AI LLM summary

    Every stage is best-effort: a failing data source is logged and replaced
    by an empty result, and a failing LLM call yields an ``"LLM error: ..."``
    summary, so the data that was fetched is still returned.

    Args:
        query: Free-text search string sent to PubMed, arXiv and the clinical
            trials lookup.
        llm: Engine name handed to ``_llm_router``.

    Returns:
        A unified dict for the UI with keys ``papers``, ``umls``,
        ``drug_safety``, ``clinical_trials``, ``variants``, ``genes``,
        ``mesh_defs``, ``gene_disease``, ``ai_summary`` and ``llm_used``.
    """

    def _settled(value: Any, label: str, fallback: Any) -> Any:
        # Swap an exception object returned by gather() for a safe default.
        if isinstance(value, BaseException):
            log.warning("%s failed: %s", label, value)
            return fallback
        return value

    # 1) Literature
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    batches = [b for b in papers_raw if b and not isinstance(b, BaseException)]
    papers = list(itertools.chain.from_iterable(batches))[:30]

    # 2) Seed keywords
    # dict.fromkeys de-duplicates while keeping first-seen order. Slicing a
    # set would pick a different subset on every run because string hashing
    # is randomized per process.
    seeds = list(dict.fromkeys(
        w
        for p in papers
        for w in (p.get("summary") or "")[:500].split()
        if w.isalpha()
    ))[:10]

    # 3) Bio-enrichment fan-out
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_task = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    # return_exceptions=True so one failing service cannot abort the search.
    umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_task,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    umls_list = _settled(umls_list, "UMLS lookup", [])
    fda_list = _settled(fda_list, "OpenFDA lookup", [])
    gene_data = _settled(
        gene_data,
        "Gene enrichment",
        {"ncbi": [], "mesh": [], "mygene": [], "ensembl": [], "ot_assoc": []},
    )
    trials = _settled(trials, "Clinical trials lookup", [])
    variants = _settled(variants, "cBioPortal lookup", [])

    # 4) Deduplicate variants by genomic coordinates
    seen: set = set()
    unique_vars: List[dict] = []
    for v in variants or []:
        key = (
            v.get("chromosome"),
            v.get("startPosition"),
            v.get("referenceAllele"),
            v.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique_vars.append(v)

    # 5) LLM-driven summary
    summarize_fn, _, engine_used = _llm_router(llm)
    combined = " ".join((p.get("summary") or "") for p in papers)
    try:
        ai_summary = await summarize_fn(combined[:12000])
    except Exception as exc:
        # Same fallback shape as answer_ai_question: report, do not raise.
        log.warning("LLM summary failed: %s", exc)
        ai_summary = f"LLM error: {exc}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }
async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Answer a follow-up question with the QA function of the selected LLM.

    The question and context are folded into a single prompt. If the QA call
    raises, the error text is returned in place of an answer.

    Returns:
        ``{"answer": <text>}``.
    """
    qa_fn = _llm_router(llm)[1]
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        return {"answer": await qa_fn(prompt)}
    except Exception as exc:
        return {"answer": f"LLM error: {exc}"}