mgbam committed on
Commit
2c2342d
·
verified ·
1 Parent(s): 1ec3999

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +75 -122
mcp/orchestrator.py CHANGED
@@ -1,127 +1,80 @@
1
- """
2
- MedGenesis – dual-LLM orchestrator
3
- ----------------------------------
4
- • Accepts llm = "openai" | "gemini" (falls back to OpenAI)
5
- • Returns one unified dict the UI can rely on.
6
- """
7
- from __future__ import annotations
8
- import asyncio, itertools, logging
9
- from typing import Dict, Any, List, Tuple
10
-
11
- from mcp.arxiv import fetch_arxiv
12
- from mcp.pubmed import fetch_pubmed
13
- from mcp.ncbi import search_gene, get_mesh_definition
14
- from mcp.mygene import fetch_gene_info
15
- from mcp.ensembl import fetch_ensembl
16
- from mcp.opentargets import fetch_ot
17
- from mcp.umls import lookup_umls
18
- from mcp.openfda import fetch_drug_safety
19
- from mcp.disgenet import disease_to_genes
20
- from mcp.clinicaltrials import search_trials
21
- from mcp.cbio import fetch_cbio
22
- from mcp.openai_utils import ai_summarize, ai_qa
23
- from mcp.gemini import gemini_summarize, gemini_qa
24
-
25
- log = logging.getLogger(__name__)
26
- _DEF = "openai" # default engine
27
-
28
-
29
- # ─────────────────────────────────── helpers ───────────────────────────────────
30
def _llm_router(engine: str = _DEF) -> Tuple:
    """Pick the summarise/QA callables for the requested LLM engine.

    Args:
        engine: Engine name; compared case-insensitively.

    Returns:
        A ``(summarise_fn, qa_fn, engine_name)`` triple. Only the name
        ``"gemini"`` selects the Gemini helpers; every other value falls
        back to the OpenAI helpers and reports ``"openai"``.
    """
    wants_gemini = engine.lower() == "gemini"
    if not wants_gemini:
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"
34
-
35
async def _gather_safely(*aws, as_list: bool = True):
    """Await *aws* concurrently without letting one failure abort the rest.

    Args:
        *aws: Awaitables (coroutines, tasks or futures) to run together.
        as_list: When true (the default), failed entries are removed, so
            the returned list may be shorter than ``aws``. When false,
            the raw ``asyncio.gather`` result is returned with exception
            objects left in place, keeping positions aligned with ``aws``.

    Returns:
        A list of results, filtered or positional as described above.
    """
    results = await asyncio.gather(*aws, return_exceptions=True)
    if not as_list:
        return results
    # BaseException rather than Exception: a cancelled child is reported
    # as CancelledError, which does not derive from Exception and would
    # otherwise slip through into the "successful" results.
    return [item for item in results if not isinstance(item, BaseException)]
42
-
43
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """Query the five gene-related sources for every key, concurrently.

    For each key, five lookups are queued in a fixed order: ``search_gene``,
    ``get_mesh_definition``, ``fetch_gene_info``, ``fetch_ensembl`` and
    ``fetch_ot``. The flat result list is then split back into one bucket
    per source.

    Args:
        keys: Search terms to look up in every source.

    Returns:
        A dict with the keys ``"ncbi"``, ``"mesh"``, ``"mygene"``,
        ``"ensembl"`` and ``"ot_assoc"``. Each value is the list of
        non-empty, successful results for that source.
    """
    jobs = []
    for key in keys:
        jobs.extend((
            search_gene(key),          # basic gene info
            get_mesh_definition(key),  # MeSH definitions
            fetch_gene_info(key),      # MyGene
            fetch_ensembl(key),        # Ensembl x-refs
            fetch_ot(key),             # Open Targets associations
        ))
    # as_list=False keeps positions intact, so result i belongs to source i % 5.
    res = await _gather_safely(*jobs, as_list=False)

    def _bucket(offset: int) -> list:
        # Every fifth entry starting at `offset` belongs to one source.
        # Exception objects are truthy, so they must be excluded explicitly
        # or failed lookups would be returned as if they were data.
        return [
            r for r in res[offset::5]
            if r and not isinstance(r, BaseException)
        ]

    return {
        "ncbi": _bucket(0),
        "mesh": _bucket(1),
        "mygene": _bucket(2),
        "ensembl": _bucket(3),
        "ot_assoc": _bucket(4),
    }
64
-
65
-
66
- # ───────────────────────────────── orchestrator ────────────────────────────────
67
async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
    """Run the full search pipeline and return one dict for the Streamlit UI.

    Steps: fetch literature, derive up to ten keyword seeds from the paper
    summaries, fan out to the enrichment sources, then summarise the
    papers with the selected LLM.

    Args:
        query: Free-text search query.
        llm: Engine name passed to ``_llm_router``.

    Returns:
        A dict with the keys ``papers``, ``umls``, ``drug_safety``,
        ``ai_summary``, ``llm_used``, ``genes``, ``mesh_defs``,
        ``gene_disease``, ``clinical_trials`` and ``variants``. A source
        that fails contributes an empty value instead of aborting the
        search.
    """
    # 1 Literature – run in parallel
    arxiv_task = asyncio.create_task(fetch_arxiv(query))
    pubmed_task = asyncio.create_task(fetch_pubmed(query))
    papers_raw = await _gather_safely(arxiv_task, pubmed_task)
    papers = list(itertools.chain.from_iterable(papers_raw))[:30]  # keep ≤30

    # 2 Keyword extraction (very light – only from abstracts).
    # A dict is used as an ordered set: slicing a plain set would pick a
    # different ten seeds on every run because str hashing is randomised.
    # NOTE(review): assumes each paper is a dict — confirm against fetchers.
    seeds: Dict[str, None] = {}
    for paper in papers:
        for word in (paper.get("summary") or "")[:500].split():
            if word.isalpha():
                seeds.setdefault(word, None)
    kws = list(itertools.islice(seeds, 10))  # coarse, fast -> 10 seeds

    # 3 Bio-enrichment fan-out
    umls_f = [_safe_task(lookup_umls, k) for k in kws]
    fda_f = [_safe_task(fetch_drug_safety, k) for k in kws]
    gene_bundle = asyncio.create_task(_gene_enrichment(kws))
    trials_task = asyncio.create_task(search_trials(query, max_studies=20))
    cbio_task = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))

    # return_exceptions=True so a failing trials/cBio/gene call degrades
    # to an empty section instead of discarding every other result.
    outcomes = await asyncio.gather(
        _gather_safely(*umls_f),
        _gather_safely(*fda_f),
        gene_bundle,
        trials_task,
        cbio_task,
        return_exceptions=True,
    )
    labels = ("umls", "drug_safety", "genes", "clinical_trials", "variants")
    fallbacks = ([], [], {}, [], [])
    settled = []
    for label, outcome, fallback in zip(labels, outcomes, fallbacks):
        if isinstance(outcome, BaseException):
            log.warning("%s lookup failed: %s", label, outcome)
            outcome = fallback
        settled.append(outcome)
    umls, fda, gene_dat, trials, variants = settled

    # 4 LLM summary
    summarise_fn, _, engine = _llm_router(llm)
    corpus = " ".join((p.get("summary") or "") for p in papers)[:12000]
    summary = await summarise_fn(corpus)

    return {
        "papers": papers,
        "umls": umls,
        "drug_safety": fda,
        "ai_summary": summary,
        "llm_used": engine,
        "genes": (
            gene_dat.get("ncbi", [])
            + gene_dat.get("ensembl", [])
            + gene_dat.get("mygene", [])
        ),
        "mesh_defs": gene_dat.get("mesh", []),
        "gene_disease": gene_dat.get("ot_assoc", []),
        "clinical_trials": trials,
        "variants": variants or [],
    }
110
 
111
- # ─────────────────────────────── follow-up QA ─────────────────────────────────
112
async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
    """Answer a follow-up question with the selected LLM.

    Args:
        question: The user's follow-up question.
        context: Background text placed in the prompt after the question.
        llm: Engine name passed to ``_llm_router``.

    Returns:
        A dict with the single key ``"answer"`` holding the QA helper's
        result.
    """
    qa_fn = _llm_router(llm)[1]  # the triple is (summarise, qa, name)
    prompt = f"Q: {question}\nContext: {context}\nA:"
    reply = await qa_fn(prompt)
    return {"answer": reply}
116
-
117
-
118
- # ─────────────────────────── internal util ───────────────────────────────────
119
def _safe_task(fn, *args):
    """Schedule ``fn(*args)`` as a Task whose failures become values.

    If awaiting ``fn(*args)`` raises an ``Exception``, a warning is logged
    and the task's result is a ``RuntimeError`` carrying the message. The
    error is returned, not raised. Uses ``asyncio.create_task``, so it
    must be called while an event loop is running.

    Args:
        fn: Callable returning an awaitable.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        The created ``asyncio.Task``.
    """
    async def _guarded():
        try:
            outcome = await fn(*args)
        except Exception as exc:
            log.warning("background task %s failed: %s", fn.__name__, exc)
            return RuntimeError(str(exc))
        return outcome

    return asyncio.create_task(_guarded())
 
1
+ # mcp/orchestrator.py
2
+
3
+ import asyncio
4
+ from mcp.pubmed import fetch_pubmed
5
+ from mcp.arxiv import fetch_arxiv
6
+ from mcp.umls import extract_umls_concepts
7
+ from mcp.openfda import fetch_drug_safety
8
+ from mcp.ncbi import search_gene, get_mesh_definition
9
+ from mcp.mygene import fetch_gene_info
10
+ from mcp.ensembl import fetch_ensembl
11
+ from mcp.opentargets import fetch_ot
12
+ from mcp.clinicaltrials import search_trials
13
+ from mcp.cbio import fetch_cbio
14
+ from mcp.gemini import gemini_summarize, gemini_qa
15
+ from mcp.openai_utils import ai_summarize, ai_qa
16
+ from mcp.disgenet import disease_to_genes
17
+
18
async def orchestrate_search(query, llm="openai"):
    """Query every data source for *query* concurrently and summarise.

    All twelve lookups receive the raw query and run in parallel. A lookup
    that raises is logged and treated as "no data", so a single failing
    source cannot abort the whole search.

    Args:
        query: Free-text search query.
        llm: ``"gemini"`` (case-insensitive) selects the Gemini
            summariser; any other value uses OpenAI.

    Returns:
        A dict with the keys ``papers``, ``genes``, ``umls``,
        ``gene_disease``, ``mesh_defs``, ``drug_safety``,
        ``clinical_trials``, ``variants``, ``ai_summary`` and ``llm_used``.
    """
    import logging  # local import: the file's import block has no logging

    # Dicts keep insertion order, so results line up with names below.
    lookups = {
        # --- Literature: PubMed + arXiv
        "pubmed": fetch_pubmed(query, max_results=7),
        "arxiv": fetch_arxiv(query, max_results=7),
        # --- UMLS, OpenFDA, Gene, Mesh
        "umls": extract_umls_concepts(query),
        "fda": fetch_drug_safety(query),
        "ncbi": search_gene(query),
        "mygene": fetch_gene_info(query),
        "ensembl": fetch_ensembl(query),
        "ot": fetch_ot(query),
        "mesh": get_mesh_definition(query),
        # --- Trials, cBio, DisGeNET
        "trials": search_trials(query, max_studies=10),
        "cbio": fetch_cbio(query),
        "disgenet": disease_to_genes(query),
    }
    outcomes = await asyncio.gather(*lookups.values(), return_exceptions=True)

    data = {}
    for name, outcome in zip(lookups, outcomes):
        if isinstance(outcome, BaseException):
            logging.getLogger(__name__).warning(
                "%s lookup failed: %s", name, outcome
            )
            outcome = None
        data[name] = outcome

    # Genes: flatten, then keep the first occurrence of each non-empty
    # entry. Entries may be dicts (unhashable), so compare by equality.
    flat = []
    for part in (data["ncbi"], data["mygene"], data["ensembl"], data["ot"]):
        if isinstance(part, list):
            flat.extend(part)
        elif isinstance(part, dict) and part:
            flat.append(part)
    genes = []
    for gene in flat:
        if gene and gene not in genes:
            genes.append(gene)

    # --- AI summary (LLM engine select)
    papers = (data["pubmed"] or []) + (data["arxiv"] or [])
    use_gemini = isinstance(llm, str) and llm.lower() == "gemini"
    summarise = gemini_summarize if use_gemini else ai_summarize
    # `or ""` guards against a summary that is present but None.
    ai_summary = await summarise(
        " ".join((p.get("summary") or "") for p in papers)
    )
    llm_used = "gemini" if use_gemini else "openai"

    disgenet, mesh, cbio = data["disgenet"], data["mesh"], data["cbio"]
    return {
        "papers": papers,
        "genes": genes,
        "umls": data["umls"] or [],
        "gene_disease": disgenet if isinstance(disgenet, list) else [],
        "mesh_defs": [mesh] if isinstance(mesh, str) and mesh else [],
        "drug_safety": data["fda"] or [],
        "clinical_trials": data["trials"] or [],
        "variants": cbio if isinstance(cbio, list) else [],
        "ai_summary": ai_summary,
        "llm_used": llm_used,
    }
70
 
71
async def answer_ai_question(question, context="", llm="openai"):
    """Answer a follow-up question with the chosen LLM.

    Gemini is used only when ``llm == "gemini"``; any other value uses
    OpenAI. If the call raises an ``Exception``, the error is not
    propagated and the other engine is not tried. The answer becomes an
    explanatory message containing the exception text.

    Args:
        question: The user's follow-up question.
        context: Background text handed to the QA helper.
        llm: Engine selector.

    Returns:
        A dict with the single key ``"answer"``.
    """
    try:
        qa_fn = gemini_qa if llm == "gemini" else ai_qa
        reply = await qa_fn(question, context)
    except Exception as e:
        reply = f"LLM unavailable or quota exceeded. ({e})"
    return {"answer": reply}