"""
MedGenesis – dual-LLM orchestrator
----------------------------------
• Accepts llm = "openai" | "gemini" (falls back to OpenAI)
• Returns one unified dict the UI can rely on.
"""
from __future__ import annotations

import asyncio, itertools, logging
from typing import Any, Callable, Dict, List, Tuple

from mcp.arxiv import fetch_arxiv
from mcp.pubmed import fetch_pubmed
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.umls import lookup_umls
from mcp.openfda import fetch_drug_safety
from mcp.disgenet import disease_to_genes
from mcp.clinicaltrials import search_trials
from mcp.cbio import fetch_cbio
from mcp.openai_utils import ai_summarize, ai_qa
from mcp.gemini import gemini_summarize, gemini_qa

log = logging.getLogger(__name__)
_DEF = "openai"


def _llm_router(engine: str = _DEF) -> Tuple[Callable, Callable, str]:
    """Return (summarise_fn, qa_fn, engine_name); unknown engines fall back to OpenAI."""
    if engine.lower() == "gemini":
        return gemini_summarize, gemini_qa, "gemini"
    return ai_summarize, ai_qa, "openai"
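# e.g. _llm_router("gemini")  -> (gemini_summarize, gemini_qa, "gemini")
#      _llm_router("unknown") -> (ai_summarize, ai_qa, "openai")   # fallback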


async def _gather_safely(*aws, as_list: bool = True):
    """gather() that never raises.

    With as_list=True, awaitables that failed are silently dropped; with
    as_list=False, the raw results (including Exception placeholders) are
    returned positionally for the caller to filter.
    """
    out = await asyncio.gather(*aws, return_exceptions=True)
    if as_list:
        return [x for x in out if not isinstance(x, Exception)]
    return out
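# Behaviour sketch (illustrative – `ok()` and `boom()` are hypothetical
# coroutines, where `boom()` raises):
#     await _gather_safely(ok(), boom())                  -> [ok_result]
#     await _gather_safely(ok(), boom(), as_list=False)   -> [ok_result, SomeError(...)]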


async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """Fan out five lookups per keyword and regroup the results by source."""
    jobs = []
    for k in keys:
        jobs += [
            search_gene(k),
            get_mesh_definition(k),
            fetch_gene_info(k),
            fetch_ensembl(k),
            fetch_ot(k),
        ]
    res = await _gather_safely(*jobs, as_list=False)

    # Results come back positionally, five per keyword, so every 5th entry
    # (offset by idx) belongs to the same source; drop failures and empties.
    combo = lambda idx: [
        r for i, r in enumerate(res)
        if i % 5 == idx and r and not isinstance(r, Exception)
    ]
    return {
        "ncbi"     : combo(0),
        "mesh"     : combo(1),
        "mygene"   : combo(2),
        "ensembl"  : combo(3),
        "ot_assoc" : combo(4),
    }
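# Rough shape of the bundle returned above (element types depend entirely on
# what the individual MCP clients return, so treat this as illustrative):
#     {"ncbi": [...], "mesh": [...], "mygene": [...], "ensembl": [...], "ot_assoc": [...]}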


async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
    """Main entry point – returns the dict consumed by the Streamlit UI."""
    # 1) Literature: query arXiv and PubMed in parallel, keep at most 30 papers.
    arxiv_task = asyncio.create_task(fetch_arxiv(query))
    pubmed_task = asyncio.create_task(fetch_pubmed(query))
    papers_raw = await _gather_safely(arxiv_task, pubmed_task)
    papers = list(itertools.chain.from_iterable(papers_raw))[:30]

    # 2) Crude keyword harvest: alphabetic tokens from the first 500 characters
    #    of each abstract, capped at 10 keywords.
    kws = {w for p in papers for w in p["summary"][:500].split() if w.isalpha()}
    kws = list(kws)[:10]

    # 3) Fan out the enrichment calls; UMLS and openFDA lookups go through
    #    _safe_task so an individual failure cannot sink the whole search.
    umls_f = [_safe_task(lookup_umls, k) for k in kws]
    fda_f = [_safe_task(fetch_drug_safety, k) for k in kws]
    gene_bundle = asyncio.create_task(_gene_enrichment(kws))
    trials_task = asyncio.create_task(search_trials(query, max_studies=20))
    cbio_task = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))

    umls, fda, gene_dat, trials, variants = await asyncio.gather(
        _gather_safely(*umls_f),
        _gather_safely(*fda_f),
        gene_bundle,
        trials_task,
        cbio_task,
    )

    # 4) Summarise the concatenated abstracts with the chosen LLM
    #    (input truncated to 12 000 characters).
    summarise_fn, _, engine = _llm_router(llm)
    summary = await summarise_fn(" ".join(p["summary"] for p in papers)[:12000])

    return {
        "papers"          : papers,
        "umls"            : umls,
        "drug_safety"     : fda,
        "ai_summary"      : summary,
        "llm_used"        : engine,
        "genes"           : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
        "mesh_defs"       : gene_dat["mesh"],
        "gene_disease"    : gene_dat["ot_assoc"],
        "clinical_trials" : trials,
        "variants"        : variants or [],
    }


async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
    """Follow-up QA using the chosen LLM."""
    _, qa_fn, _ = _llm_router(llm)
    return {"answer": await qa_fn(f"Q: {question}\nContext: {context}\nA:")}


def _safe_task(fn, *args):
    """Helper that wraps an async callable in a Task returning RuntimeError on exception."""
    async def _wrapper():
        try:
            return await fn(*args)
        except Exception as exc:
            log.warning("background task %s failed: %s", fn.__name__, exc)
            return RuntimeError(str(exc))
    return asyncio.create_task(_wrapper())
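# Contract sketch (hypothetical `fetch` coroutine): awaiting the task never
# raises – on failure you get a RuntimeError *value* back instead:
#     t = _safe_task(fetch, "BRCA1")
#     res = await t          # either fetch()'s result or RuntimeError("...")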
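

# Local smoke test – a sketch only: it assumes the MCP backends and the
# OpenAI / Gemini credentials are configured in the environment, and the
# default query string is just an example.
if __name__ == "__main__":
    import json, sys

    async def _demo() -> None:
        query = sys.argv[1] if len(sys.argv) > 1 else "glioblastoma immunotherapy"
        data = await orchestrate_search(query)
        print(json.dumps({k: v for k, v in data.items() if k != "papers"},
                         indent=2, default=str))
        print(f"... plus {len(data['papers'])} papers (engine: {data['llm_used']})")

    asyncio.run(_demo())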