MCP_Res / mcp /orchestrator.py
mgbam's picture
Update mcp/orchestrator.py
146b143 verified
raw
history blame
5.21 kB
#!/usr/bin/env python3
"""MedGenesis – orchestrator (v4.1, context‑safe)
Runs an async pipeline that fetches literature, enriches with biomedical
APIs, and summarises via either OpenAI or Gemini. Fully resilient:
• HTTPS arXiv
• 403‑proof ClinicalTrials.gov helper
• Filters out failed enrichment calls so UI never crashes
• Follow‑up QA passes `context=` kwarg (fixes TypeError)
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
# ── async fetchers ──────────────────────────────────────────────────
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.nlp import extract_keywords
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.disgenet import disease_to_genes
from mcp.mygene import fetch_gene_info
from mcp.ctgov import search_trials  # v2→v1 helper
# ── LLM helpers ────────────────────────────────────────────────────
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa
# ------------------------------------------------------------------
# LLM router
# ------------------------------------------------------------------
_DEF_LLM = "openai"


def _llm_router(name: str | None):
    """Pick the summarise/QA helper pair for the requested LLM backend.

    Returns a ``(summarize_fn, qa_fn, engine_label)`` tuple. Only a name that
    lower-cases to ``"gemini"`` selects the Gemini helpers; ``None``, an empty
    string, or any other value selects the OpenAI helpers.
    """
    wants_gemini = bool(name) and name.lower() == "gemini"
    if not wants_gemini:
        # Default route: anything that is not explicitly "gemini".
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"
# ------------------------------------------------------------------
# Keyword enrichment bundle (NCBI / MeSH / DisGeNET)
# ------------------------------------------------------------------
async def _enrich_keywords(keys: List[str]) -> Dict[str, Any]:
    """Run the gene, MeSH and DisGeNET lookups for every keyword.

    For each keyword, ``search_gene``, ``get_mesh_definition`` and
    ``disease_to_genes`` are scheduled and all of them are awaited
    concurrently. Lookups that fail are dropped instead of propagating.

    Args:
        keys: Keywords to look up.

    Returns:
        A dict with three lists: ``"genes"`` (all ``search_gene`` results
        concatenated), ``"meshes"`` (one entry per successful
        ``get_mesh_definition`` call) and ``"disgenet"`` (all
        ``disease_to_genes`` results concatenated).
    """
    lookups = (search_gene, get_mesh_definition, disease_to_genes)
    results = await asyncio.gather(
        *(lookup(key) for key in keys for lookup in lookups),
        return_exceptions=True,
    )

    genes: List[Any] = []
    meshes: List[Any] = []
    disg: List[Any] = []
    for idx, res in enumerate(results):
        # BaseException, not Exception: a cancelled lookup comes back from
        # gather() as CancelledError, which does not derive from Exception.
        if isinstance(res, BaseException):
            continue
        # Results arrive in submission order, so position modulo the number
        # of lookups identifies which helper produced this entry.
        source = lookups[idx % len(lookups)]
        if source is get_mesh_definition:
            meshes.append(res)
        elif res:
            # Skip None/empty payloads so extend() is never called on None.
            (genes if source is search_gene else disg).extend(res)
    return {"genes": genes, "meshes": meshes, "disgenet": disg}
# ------------------------------------------------------------------
# Orchestrator main
# ------------------------------------------------------------------
async def orchestrate_search(query: str, *, llm: str = _DEF_LLM) -> Dict[str, Any]:
    """Fetch + enrich + summarise; returns dict for Streamlit UI.

    Pipeline:
        1. Fetch papers from arXiv and PubMed concurrently.
        2. Join the paper summaries into one corpus and keep the first
           eight extracted keywords.
        3. Fan out the enrichment calls (UMLS, drug safety, keyword
           enrichment, gene info, clinical trials) concurrently.
        4. Summarise the corpus with the selected LLM.

    A failing fetch or enrichment call degrades to an empty value for its
    part of the payload instead of aborting the whole search. A failure of
    the LLM summary call itself is not caught and still propagates.

    Args:
        query: Free-text search query.
        llm: Name of the LLM backend, resolved by ``_llm_router``.

    Returns:
        A dict with the keys ``papers``, ``umls``, ``drug_safety``,
        ``ai_summary``, ``llm_used``, ``genes``, ``mesh_defs``,
        ``gene_disease`` and ``clinical_trials``.
    """
    # 1) Literature --------------------------------------------------
    literature = await asyncio.gather(
        fetch_arxiv(query, max_results=10),
        fetch_pubmed(query, max_results=10),
        return_exceptions=True,
    )
    papers: List[Dict] = []
    for batch in literature:
        # Skip failed sources and empty/None payloads alike.
        if isinstance(batch, BaseException) or not batch:
            continue
        papers.extend(batch)

    # 2) Keyword extraction -----------------------------------------
    # `or ""` guards against a summary that is present but None, which
    # would otherwise make str.join raise TypeError.
    corpus = " ".join(p.get("summary") or "" for p in papers)
    keywords = extract_keywords(corpus)[:8]

    # 3) Enrichment fan-out -----------------------------------------
    # return_exceptions=True on the outer gather as well: without it a
    # single failing gene/trials/keyword call would abort the whole search.
    umls, fda, ncbi, mygene, trials = await asyncio.gather(
        asyncio.gather(*(lookup_umls(k) for k in keywords), return_exceptions=True),
        asyncio.gather(*(fetch_drug_safety(k) for k in keywords), return_exceptions=True),
        _enrich_keywords(keywords),
        fetch_gene_info(query),
        search_trials(query, max_studies=20),
        return_exceptions=True,
    )

    # filter out failed calls --------------------------------------
    umls = [u for u in _or_default(umls, []) if isinstance(u, dict)]
    fda = [d for d in _or_default(fda, []) if isinstance(d, (dict, list))]
    ncbi = _or_default(ncbi, None) or {}
    mygene = _or_default(mygene, None)
    trials = _or_default(trials, [])

    # 4) LLM summary -------------------------------------------------
    summarize_fn, _, engine = _llm_router(llm)
    # A whitespace-only corpus carries nothing to summarise.
    ai_summary = await summarize_fn(corpus) if corpus.strip() else ""

    # 5) Assemble payload -------------------------------------------
    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": ai_summary,
        "llm_used": engine,
        "genes": (ncbi.get("genes") or []) + ([mygene] if mygene else []),
        "mesh_defs": ncbi.get("meshes", []),
        "gene_disease": ncbi.get("disgenet", []),
        "clinical_trials": trials,
    }


def _or_default(result: Any, default: Any) -> Any:
    """Return *default* when a gathered *result* is a captured exception."""
    return default if isinstance(result, BaseException) else result
# ------------------------------------------------------------------
async def answer_ai_question(question: str, *, context: str, llm: str = _DEF_LLM) -> Dict[str, str]:
    """Answer a follow-up question with the selected LLM.

    The QA helper chosen by ``_llm_router(llm)`` is awaited with *question*
    positionally and *context* as a keyword argument; its result is returned
    under the ``"answer"`` key.
    """
    _summarize_fn, qa_fn, _engine = _llm_router(llm)
    reply = await qa_fn(question, context=context)
    return {"answer": reply}