#!/usr/bin/env python3
# mcp/orchestrator.py
"""
MedGenesis – dual-LLM orchestrator (v4)
---------------------------------------
• Accepts llm="openai" | "gemini" (defaults to OpenAI)
• Safely runs all data-source calls in parallel
• Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
• Returns one dict that the Streamlit UI can rely on
"""
from __future__ import annotations
import asyncio, itertools, logging
from typing import Dict, Any, List
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import fetch_clinical_trials
from mcp.cbio import fetch_cbio_variants
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa
# Module-level logger; handlers/level are configured by the host application.
log = logging.getLogger(__name__)
# Engine used when callers do not pass an explicit `llm=` argument.
_DEFAULT_LLM = "openai"
def _llm_router(engine: str = _DEFAULT_LLM):
    """Map an engine name to its (summarize_fn, qa_fn, canonical_name) triple.

    Anything other than "gemini" (case-insensitive) falls back to OpenAI.
    """
    wants_gemini = engine.lower() == "gemini"
    if wants_gemini:
        return gemini_summarize, gemini_qa, "gemini"
    return ai_summarize, ai_qa, "openai"
async def _safe_gather(*tasks, return_exceptions: bool = False):
"""
Wrapper around asyncio.gather that logs failures
and optionally returns exceptions as results.
"""
results = await asyncio.gather(*tasks, return_exceptions=True)
cleaned = []
for idx, res in enumerate(results):
if isinstance(res, Exception):
log.warning("Task %d failed: %s", idx, res)
if return_exceptions:
cleaned.append(res)
else:
cleaned.append(res)
return cleaned
async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point for MedGenesis UI.

    Runs literature search, keyword seeding, bio-enrichment fan-out and an
    LLM summary, degrading gracefully when individual data sources fail.

    Args:
        query: Free-text search string from the UI.
        llm:   "openai" (default) or "gemini" — which engine writes the summary.

    Returns a dict with:
        - papers, umls, drug_safety, clinical_trials, variants
        - genes, mesh_defs, gene_disease
        - ai_summary, llm_used
    """
    # 1) Literature (PubMed + arXiv in parallel)
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    # Only chain real lists: a fetcher that returns None (not an Exception)
    # would otherwise crash chain.from_iterable.
    papers = list(
        itertools.chain.from_iterable(p for p in papers_raw if isinstance(p, list))
    )[:30]
    # 2) Keyword seeds from abstracts (first 500 chars, split on whitespace)
    seeds = {
        w.strip()
        for p in papers
        for w in p.get("summary", "")[:500].split()
        if w.isalpha()
    }
    seeds = list(seeds)[:10]
    # 3) Fan-out all bio-enrichment tasks safely
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )
    # return_exceptions=True so one failing source degrades gracefully instead
    # of aborting the whole orchestration (the module's stated design goal).
    umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_enrich_t,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    # Fall back to harmless defaults for any branch that failed outright.
    if not isinstance(gene_data, dict):
        log.warning("Gene enrichment failed: %s", gene_data)
        gene_data = {"ncbi": [], "mesh": [], "mygene": [], "ensembl": [], "ot_assoc": []}
    if isinstance(trials, Exception):
        log.warning("ClinicalTrials fetch failed: %s", trials)
        trials = []
    if isinstance(variants, Exception):
        log.warning("cBioPortal fetch failed: %s", variants)
        variants = []
    if isinstance(umls_list, Exception):
        umls_list = []
    if isinstance(fda_list, Exception):
        fda_list = []
    # 4) Deduplicate and flatten genes across all enrichment sources
    genes = {
        g["symbol"]
        for source in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
        for g in source if isinstance(g, dict) and g.get("symbol")
    }
    genes = list(genes)
    # 5) Dedupe variants by (chrom, pos, ref, alt) when returned as dicts;
    #    pass anything else through untouched rather than crashing on .get.
    seen = set()
    unique_vars: List[dict] = []
    for var in variants or []:
        if not isinstance(var, dict):
            unique_vars.append(var)
            continue
        key = (
            var.get("chromosome"),
            var.get("startPosition"),
            var.get("referenceAllele"),
            var.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique_vars.append(var)
    # 6) LLM summary (skip the call entirely when there is nothing to summarize)
    summarize_fn, _, engine_used = _llm_router(llm)
    long_text = " ".join(p.get("summary", "") for p in papers)
    ai_summary = await summarize_fn(long_text[:12000]) if long_text.strip() else ""
    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, Exception)],
        "drug_safety": list(itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan out the five gene-centric lookups for every seed key:
      NCBI gene, MeSH definition, MyGene.info, Ensembl xrefs,
      OpenTargets associations.

    Returns:
        Dict with keys "ncbi", "mesh", "mygene", "ensembl", "ot_assoc",
        each holding the list of successful results for that source.
    """
    lookups = (search_gene, get_mesh_definition, fetch_gene_info,
               fetch_ensembl, fetch_ot)
    tasks = [
        asyncio.create_task(fn(key))
        for key in keys
        for fn in lookups
    ]
    raw = await _safe_gather(*tasks, return_exceptions=True)
    width = len(lookups)

    def _column(offset: int) -> list:
        # Each key contributed `width` consecutive results; take the
        # offset-th of every group, skipping failed lookups.
        return [r for i, r in enumerate(raw)
                if i % width == offset and not isinstance(r, Exception)]

    return {
        "ncbi": _column(0),
        "mesh": _column(1),
        "mygene": _column(2),
        "ensembl": _column(3),
        "ot_assoc": _column(4),
    }
async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Follow-up QA: wraps the chosen LLM’s QA function.

    Args:
        question: The user's follow-up question.
        context:  Text grounding the answer.
        llm:      Engine name, routed through _llm_router.

    Returns:
        {"answer": <LLM reply, or an "LLM error: ..." string on failure>}
    """
    _, qa_fn, _ = _llm_router(llm)
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        reply = await qa_fn(prompt)
    except Exception as exc:
        reply = f"LLM error: {exc}"
    return {"answer": reply}