MCP_Res / mcp /orchestrator.py
mgbam's picture
Update mcp/orchestrator.py
86771dc verified
raw
history blame
5.79 kB
#!/usr/bin/env python3
"""
MedGenesis – dual‑LLM asynchronous orchestrator
==============================================
• Accepts `llm` argument ("openai" | "gemini"), defaults to "openai".
• Harvests literature (PubMed + arXiv) → extracts keywords.
• Fans out to open APIs for genes, trials, safety, ontology:
    – **MyGene.info** for live gene annotations
    – **ClinicalTrials.gov v2** for recruiting & completed studies
    – UMLS / openFDA / DisGeNET / MeSH (existing helpers)
    – Optional Open Targets & DrugCentral via `multi_enrich` if needed.
• Returns a single JSON-serialisable dict consumed by the Streamlit UI.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
# ── Literature fetchers ─────────────────────────────────────────────
from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
# ── NLP & legacy enrichers ─────────────────────────────────────────
from mcp.nlp import extract_keywords
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.disgenet import disease_to_genes
# ── Modern high‑throughput APIs ────────────────────────────────────
from mcp.mygene import fetch_gene_info # MyGene.info
from mcp.ctgov import search_trials_v2 # ClinicalTrials.gov v2
# from mcp.targets import fetch_ot_associations # (optional future use)
# ── LLM utilities ─────────────────────────────────────────────────
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa
# ------------------------------------------------------------------
# LLM router
# ------------------------------------------------------------------
def _get_llm(llm: str):
    """Pick the (summarize, QA) function pair for the requested engine.

    Args:
        llm: Engine name. Only a case-insensitive ``"gemini"`` selects the
            Gemini helpers; every other value, including ``None`` or an
            empty string, selects the OpenAI helpers.

    Returns:
        A ``(summarize_fn, qa_fn)`` tuple.
    """
    wants_gemini = bool(llm) and llm.lower() == "gemini"
    if not wants_gemini:
        # Default route: OpenAI.
        return ai_summarize, ai_qa
    return gemini_summarize, gemini_qa
# ------------------------------------------------------------------
# Helper: batch NCBI / MeSH / DisGeNET enrichment for keyword list
# ------------------------------------------------------------------
async def _enrich_ncbi_mesh_disg(keys: List[str]) -> Dict[str, Any]:
    """Run gene, MeSH and DisGeNET lookups concurrently for every keyword.

    Args:
        keys: Keywords to look up. Each one is passed to ``search_gene``,
            ``get_mesh_definition`` and ``disease_to_genes``.

    Returns:
        A dict with three lists:

        * ``"genes"``    - flattened ``search_gene`` results,
        * ``"meshes"``   - one ``get_mesh_definition`` result per lookup
          that did not fail,
        * ``"disgenet"`` - flattened ``disease_to_genes`` results.

        Lookups that raise or are cancelled are skipped (best effort), so
        this helper never propagates an individual lookup failure.
    """
    n = len(keys)
    jobs = (
        [search_gene(k) for k in keys]
        + [get_mesh_definition(k) for k in keys]
        + [disease_to_genes(k) for k in keys]
    )
    results = await asyncio.gather(*jobs, return_exceptions=True)

    def _succeeded(batch):
        # BaseException rather than Exception: asyncio.CancelledError has
        # derived from BaseException since Python 3.8, so a cancelled lookup
        # would otherwise be treated as a successful result.
        return [res for res in batch if not isinstance(res, BaseException)]

    # `jobs` was built as three consecutive runs of length n, so the result
    # list can be split back into its sources by position.
    gene_results = _succeeded(results[:n])
    mesh_defs = _succeeded(results[n:2 * n])
    disg_results = _succeeded(results[2 * n:])

    # `res or []` tolerates a helper that returns None instead of a list.
    genes = [hit for res in gene_results for hit in (res or [])]
    disg_links = [link for res in disg_results for link in (res or [])]

    return {"genes": genes, "meshes": mesh_defs, "disgenet": disg_links}
# ------------------------------------------------------------------
# Main orchestrator
# ------------------------------------------------------------------
async def orchestrate_search(query: str, *, llm: str = "openai") -> Dict[str, Any]:
    """Run the literature -> enrichment -> summary pipeline for *query*.

    Args:
        query: Search string passed to the arXiv and PubMed fetchers, to
            ``fetch_gene_info`` and to ``search_trials_v2``.
        llm: Requested summarisation engine. A case-insensitive ``"gemini"``
            selects Gemini; any other value (including ``None``) falls back
            to OpenAI, as decided by ``_get_llm``.

    Returns:
        A dict with the keys ``papers``, ``umls``, ``drug_safety``,
        ``ai_summary``, ``llm_used``, ``genes``, ``mesh_defs``,
        ``gene_disease`` and ``clinical_trials``. ``llm_used`` names the
        engine that actually produced the summary (``"gemini"`` or
        ``"openai"``). Individual UMLS / openFDA lookups that fail are
        omitted from their lists.

    Raises:
        Whatever the literature fetchers, ``fetch_gene_info``,
        ``search_trials_v2`` or the summariser raise; only the per-keyword
        lookups are best effort.
    """
    # 1) Literature --------------------------------------------------
    # Both sources are queried concurrently, then flattened into one list.
    batches = await asyncio.gather(fetch_arxiv(query), fetch_pubmed(query))
    papers = [paper for batch in batches for paper in (batch or [])]

    # 2) Keyword extraction -----------------------------------------
    corpus = " ".join(p["summary"] for p in papers)
    keywords = extract_keywords(corpus)[:8]  # cap the fan-out at 8 keywords

    # 3) Fan-out enrichment -----------------------------------------
    umls_jobs = [lookup_umls(k) for k in keywords]
    fda_jobs = [fetch_drug_safety(k) for k in keywords]
    umls_raw, fda_raw, ncbi_data, mygene, trials = await asyncio.gather(
        asyncio.gather(*umls_jobs, return_exceptions=True),
        asyncio.gather(*fda_jobs, return_exceptions=True),
        _enrich_ncbi_mesh_disg(keywords),
        fetch_gene_info(query),  # top gene hit
        search_trials_v2(query, max_n=20),
    )
    # return_exceptions=True hands failures back as exception objects. They
    # are dropped here so they never reach the payload, which must stay
    # JSON-serialisable. BaseException also covers asyncio.CancelledError.
    umls = [res for res in umls_raw if not isinstance(res, BaseException)]
    fda = [res for res in fda_raw if not isinstance(res, BaseException)]

    # 4) LLM summary -------------------------------------------------
    summarize_fn, _ = _get_llm(llm)
    ai_summary = await summarize_fn(corpus)
    # Derive the reported engine from the router's decision instead of
    # echoing the raw argument: the argument may be None or an unknown name
    # that silently fell back to OpenAI.
    llm_used = "gemini" if summarize_fn is gemini_summarize else "openai"

    # 5) Assemble payload -------------------------------------------
    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": ai_summary,
        "llm_used": llm_used,
        # Gene & variant context
        "genes": (ncbi_data["genes"] or []) + ([mygene] if mygene else []),
        "mesh_defs": ncbi_data["meshes"],
        "gene_disease": ncbi_data["disgenet"],
        # Clinical trials
        "clinical_trials": trials,
    }
# ------------------------------------------------------------------
async def answer_ai_question(question: str, *, context: str, llm: str = "openai") -> Dict[str, str]:
    """Answer a single follow-up question with the selected LLM engine.

    Args:
        question: The question to put to the model.
        context: Text passed to the QA function alongside the question.
        llm: Engine name, resolved by ``_get_llm``.

    Returns:
        A dict with one key, ``"answer"``, holding the QA function's result.
    """
    qa_fn = _get_llm(llm)[1]
    reply = await qa_fn(question, context)
    return {"answer": reply}