#!/usr/bin/env python3
# mcp/orchestrator.py
"""
MedGenesis – dual-LLM orchestrator (v4)
---------------------------------------
• Accepts llm="openai" | "gemini" (defaults to OpenAI)
• Safely runs all data-source calls in parallel
• Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
• Returns one dict that the Streamlit UI can rely on
"""
from __future__ import annotations

import asyncio
import itertools
import logging
from typing import Any, Dict, List

from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import fetch_clinical_trials
from mcp.cbio import fetch_cbio_variants
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa

log = logging.getLogger(__name__)

_DEFAULT_LLM = "openai"

# Bucket names produced by `_gene_enrichment`, in the same order as the
# per-seed lookups it schedules.
_GENE_BUCKETS = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")


def _llm_router(engine: str = _DEFAULT_LLM):
    """Return ``(summarize_fn, qa_fn, engine_name)`` for the requested engine.

    The match is case-insensitive and ignores surrounding whitespace.
    Anything other than "gemini" -- including ``None`` or an empty string --
    selects the OpenAI functions.
    """
    name = str(engine or _DEFAULT_LLM).strip().lower()
    if name == "gemini":
        return gemini_summarize, gemini_qa, "gemini"
    return ai_summarize, ai_qa, "openai"


def _value_or(result: Any, fallback: Any) -> Any:
    """Return *fallback* when *result* is an exception object, else *result*."""
    return fallback if isinstance(result, BaseException) else result


async def _safe_gather(*tasks, return_exceptions: bool = False):
    """
    Wrapper around asyncio.gather that logs failures instead of raising.

    With ``return_exceptions=True`` the result list keeps one entry per task
    (failed tasks contribute their exception object), so positions line up
    with the inputs. With the default ``False`` failed tasks are dropped and
    the list may be shorter than the input.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    cleaned = []
    for idx, res in enumerate(results):
        # BaseException, not Exception: a cancelled child is reported as
        # asyncio.CancelledError, which is not an Exception subclass.
        if isinstance(res, BaseException):
            log.warning("Task %d failed: %s", idx, res)
            if return_exceptions:
                cleaned.append(res)
        else:
            cleaned.append(res)
    return cleaned


def _extract_seeds(papers: List[Any], limit: int = 10, window: int = 500) -> List[str]:
    """Pick up to *limit* distinct alphabetic words from paper summaries.

    Only the first *window* characters of each summary are scanned. Words are
    kept in first-seen order so the same papers always yield the same seeds.
    """
    words = (
        word
        for paper in papers
        # `or ""` also covers a "summary" key that is present but None.
        for word in (paper.get("summary") or "")[:window].split()
        if word.isalpha()
    )
    return list(itertools.islice(dict.fromkeys(words), limit))


async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
    """
    Main entry point for MedGenesis UI.

    Returns a dict with:
      - papers, umls, drug_safety, clinical_trials, variants
      - genes, mesh_defs, gene_disease
      - ai_summary, llm_used

    A failing data source is logged and contributes an empty value rather
    than aborting the search; a failing LLM call yields an "LLM error: ..."
    string in ``ai_summary``.
    """
    # 1) Literature (PubMed + arXiv in parallel); a failed source is dropped.
    pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
    arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
    papers_raw = await _safe_gather(pubmed_t, arxiv_t)
    papers = list(
        itertools.chain.from_iterable(batch for batch in papers_raw if batch)
    )[:30]

    # 2) Keyword seeds from abstracts (first 500 chars, split on whitespace)
    seeds = _extract_seeds(papers)

    # 3) Fan-out all bio-enrichment tasks safely
    umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
    fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
    gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
    trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
    cbio_t = asyncio.create_task(
        fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
    )

    # return_exceptions=True keeps the five results positional so they can be
    # unpacked, and stops one failing source from failing the whole search.
    umls_list, fda_list, gene_data, trials, variants = await _safe_gather(
        _safe_gather(*umls_tasks, return_exceptions=True),
        _safe_gather(*fda_tasks, return_exceptions=True),
        gene_enrich_t,
        trials_t,
        cbio_t,
        return_exceptions=True,
    )
    umls_list = _value_or(umls_list, [])
    fda_list = _value_or(fda_list, [])
    gene_data = _value_or(gene_data, {name: [] for name in _GENE_BUCKETS})
    trials = _value_or(trials, [])
    variants = _value_or(variants, [])

    # 4) Dedupe variants by (chrom, pos, ref, alt) if returned as dicts
    seen = set()
    unique_vars: List[Any] = []
    for var in variants or []:
        if not isinstance(var, dict):
            # No fields to build a key from; keep the item untouched.
            unique_vars.append(var)
            continue
        key = (
            var.get("chromosome"),
            var.get("startPosition"),
            var.get("referenceAllele"),
            var.get("variantAllele"),
        )
        if key not in seen:
            seen.add(key)
            unique_vars.append(var)

    # 5) LLM summary (input capped at 12000 characters)
    summarize_fn, _, engine_used = _llm_router(llm)
    long_text = " ".join(p.get("summary") or "" for p in papers)
    try:
        ai_summary = await summarize_fn(long_text[:12000])
    except Exception as e:
        # Same degradation as answer_ai_question: report, don't crash.
        log.warning("LLM summary failed: %s", e)
        ai_summary = f"LLM error: {e}"

    return {
        "papers": papers,
        "umls": [u for u in umls_list if not isinstance(u, BaseException)],
        "drug_safety": list(
            itertools.chain.from_iterable(
                dfa for dfa in fda_list if isinstance(dfa, list)
            )
        ),
        "clinical_trials": trials or [],
        "variants": unique_vars,
        "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
        "mesh_defs": gene_data["mesh"],
        "gene_disease": gene_data["ot_assoc"],
        "ai_summary": ai_summary,
        "llm_used": engine_used,
    }


async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan-out gene-related tasks for each seed key:
      - NCBI gene lookup
      - MeSH definition
      - MyGene.info
      - Ensembl xrefs
      - OpenTargets associations

    Returns a dict of lists keyed by ``_GENE_BUCKETS``; each list holds the
    successful results for that source, failed lookups are omitted.
    """
    # Order must match _GENE_BUCKETS: results are regrouped by position.
    fetchers = (search_gene, get_mesh_definition, fetch_gene_info, fetch_ensembl, fetch_ot)
    jobs = [asyncio.create_task(fetch(k)) for k in keys for fetch in fetchers]

    # return_exceptions=True keeps one slot per job so the stride below holds.
    results = await _safe_gather(*jobs, return_exceptions=True)

    stride = len(fetchers)
    return {
        name: [
            r for r in results[offset::stride] if not isinstance(r, BaseException)
        ]
        for offset, name in enumerate(_GENE_BUCKETS)
    }


async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
    """
    Follow-up QA: wraps the QA function of the chosen LLM.

    Returns ``{"answer": ...}``. If the LLM call raises, the answer is an
    "LLM error: ..." string instead of an exception.
    """
    _, qa_fn, _ = _llm_router(llm)
    prompt = f"Q: {question}\nContext: {context}\nA:"
    try:
        answer = await qa_fn(prompt)
    except Exception as e:
        log.warning("LLM QA failed: %s", e)
        answer = f"LLM error: {e}"
    return {"answer": answer}