mgbam committed
Commit c30e46a · verified · 1 parent: 19e03c6

Update mcp/orchestrator.py

Files changed (1)
  1. mcp/orchestrator.py  +101 -88
mcp/orchestrator.py CHANGED
@@ -1,114 +1,127 @@
  """
- MedGenesis – dual-LLM orchestrator (OpenAI + Gemini)
- ----------------------------------------------------
- Returns a single dict the UI expects. New keys:
-
-   • variants      – mutation summaries from cBioPortal
-   • variant_count – quick count for empty-tab logic
  """

- import asyncio
- from typing import Dict, Any, List
-
- # literature + NLP
  from mcp.arxiv  import fetch_arxiv
  from mcp.pubmed import fetch_pubmed
- from mcp.nlp    import extract_keywords
-
- # enrichment
- from mcp.umls    import lookup_umls
- from mcp.openfda import fetch_drug_safety
  from mcp.ncbi    import search_gene, get_mesh_definition
- from mcp.disgenet import disease_to_genes
- from mcp.clinicaltrials import search_trials
  from mcp.mygene  import fetch_gene_info
  from mcp.ensembl import fetch_ensembl
  from mcp.opentargets import fetch_ot
- from mcp.cbio import fetch_cbio   # NEW
-
- # LLMs
  from mcp.openai_utils import ai_summarize, ai_qa
  from mcp.gemini       import gemini_summarize, gemini_qa

- _DEF = "openai"

- def _llm_router(llm: str):
-     llm = (llm or _DEF).lower()
-     if llm == "gemini":
-         return ("gemini", gemini_summarize, gemini_qa)
-     return ("openai", ai_summarize, ai_qa)

- # ---------------- gene meta helper ----------------
- async def _resolve_gene(sym: str) -> Dict[str, Any]:
-     for fn in (fetch_gene_info, fetch_ensembl, fetch_ot):
-         try:
-             data = await fn(sym)
-             if data:
-                 return data
-         except Exception:
-             continue
-     return {}
-
- # ---------------- orchestrator --------------------
  async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
-     # 1 literature ---------------------------------------------------
-     arxiv_f  = asyncio.create_task(fetch_arxiv(query))
-     pubmed_f = asyncio.create_task(fetch_pubmed(query))
-     papers   = sum(await asyncio.gather(arxiv_f, pubmed_f), [])
-
-     # 2 keywords -----------------------------------------------------
-     blob = " ".join(p["summary"] for p in papers)
-     keys = extract_keywords(blob)[:8] if blob else []
-
-     # 3 parallel enrichment -----------------------------------------
-     umls_f    = [lookup_umls(k)         for k in keys]
-     fda_f     = [fetch_drug_safety(k)   for k in keys]
-     ncbi_f    = [search_gene(k)         for k in keys]
-     mesh_f    = [get_mesh_definition(k) for k in keys]
-     gene_meta = [_resolve_gene(k)       for k in keys[:3]]   # cheap
-     trials_f  = asyncio.create_task(search_trials(query, max_studies=20))
-
-     # primary await
-     (
-         umls, fda, ncbi, meshes, gmeta, trials
-     ) = await asyncio.gather(
-         asyncio.gather(*umls_f, return_exceptions=True),
-         asyncio.gather(*fda_f,  return_exceptions=True),
-         asyncio.gather(*ncbi_f, return_exceptions=True),
-         asyncio.gather(*mesh_f, return_exceptions=True),
-         asyncio.gather(*gene_meta, return_exceptions=True),
-         trials_f,
      )

-     # 4 variants (fire & forget; don't fail whole run) --------------
-     var_jobs = [fetch_cbio(g.get("symbol") or k)
-                 for g, k in zip(gmeta, keys[:len(gmeta)])]
-     try:
-         variants = sum(await asyncio.gather(*var_jobs), [])
-     except Exception:
-         variants = []
-
-     # 5 LLM summary -------------------------------------------------
-     _, summarise, _ = _llm_router(llm)
-     summary = await summarise(blob) if blob else "No abstracts found."

      return {
          "papers"          : papers,
          "umls"            : umls,
          "drug_safety"     : fda,
-         "genes"           : sum(ncbi, []),
-         "mesh_defs"       : meshes,
-         "gene_meta"       : gmeta,
-         "gene_disease"    : await disease_to_genes(query) or [],
-         "clinical_trials" : trials,
-         "variants"        : variants,
-         "variant_count"   : len(variants),
          "ai_summary"      : summary,
-         "llm_used"        : llm.lower(),
      }

- # ---------------- follow-up QA --------------------
  async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
-     _, _, qa_fn = _llm_router(llm)
-     ans = await qa_fn(question, context)
-     return {"answer": ans}
  """
+ MedGenesis – dual-LLM orchestrator
+ ----------------------------------
+ • Accepts llm = "openai" | "gemini"  (falls back to OpenAI)
+ • Returns one unified dict the UI can rely on.
  """
+ from __future__ import annotations
+ import asyncio, itertools, logging
+ from typing import Dict, Any, List, Tuple

  from mcp.arxiv  import fetch_arxiv
  from mcp.pubmed import fetch_pubmed
  from mcp.ncbi    import search_gene, get_mesh_definition
  from mcp.mygene  import fetch_gene_info
  from mcp.ensembl import fetch_ensembl
  from mcp.opentargets import fetch_ot
+ from mcp.umls    import lookup_umls
+ from mcp.openfda import fetch_drug_safety
+ from mcp.disgenet import disease_to_genes
+ from mcp.clinicaltrials import search_trials
+ from mcp.cbio    import fetch_cbio
  from mcp.openai_utils import ai_summarize, ai_qa
  from mcp.gemini       import gemini_summarize, gemini_qa

+ log  = logging.getLogger(__name__)
+ _DEF = "openai"          # default engine
+
+
+ # ─────────────────────────────────── helpers ───────────────────────────────────
+ def _llm_router(engine: str = _DEF) -> Tuple:
+     if engine.lower() == "gemini":
+         return gemini_summarize, gemini_qa, "gemini"
+     return ai_summarize, ai_qa, "openai"
+
+ async def _gather_safely(*aws, as_list: bool = True):
+     """await gather() that converts Exception → RuntimeError placeholder"""
+     out = await asyncio.gather(*aws, return_exceptions=True)
+     if as_list:
+         # filter exceptions – keep structure but drop failures
+         return [x for x in out if not isinstance(x, Exception)]
+     return out
+
+ async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
+     jobs = []
+     for k in keys:
+         jobs += [
+             search_gene(k),           # basic gene info
+             get_mesh_definition(k),   # MeSH definitions
+             fetch_gene_info(k),       # MyGene
+             fetch_ensembl(k),         # Ensembl x-refs
+             fetch_ot(k),              # Open Targets associations
+         ]
+     res = await _gather_safely(*jobs, as_list=False)
+
+     # slice & compress five-way fan-out
+     combo = lambda idx: [r for i, r in enumerate(res) if i % 5 == idx and r]
+     return {
+         "ncbi"     : combo(0),
+         "mesh"     : combo(1),
+         "mygene"   : combo(2),
+         "ensembl"  : combo(3),
+         "ot_assoc" : combo(4),
+     }


+ # ───────────────────────────────── orchestrator ────────────────────────────────
  async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
+     """Main entry – returns dict for the Streamlit UI"""
+     # 1 Literature – run in parallel
+     arxiv_task  = asyncio.create_task(fetch_arxiv(query))
+     pubmed_task = asyncio.create_task(fetch_pubmed(query))
+     papers_raw  = await _gather_safely(arxiv_task, pubmed_task)
+     papers      = list(itertools.chain.from_iterable(papers_raw))[:30]   # keep ≤ 30
+
+     # 2 Keyword extraction (very light – only from abstracts)
+     kws = {w for p in papers for w in p["summary"][:500].split() if w.isalpha()}
+     kws = list(kws)[:10]                         # coarse, fast -> 10 seeds
+
+     # 3 Bio-enrichment fan-out
+     umls_f      = [_safe_task(lookup_umls, k)       for k in kws]
+     fda_f       = [_safe_task(fetch_drug_safety, k) for k in kws]
+     gene_bundle = asyncio.create_task(_gene_enrichment(kws))
+     trials_task = asyncio.create_task(search_trials(query, max_studies=20))
+     cbio_task   = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))
+
+     umls, fda, gene_dat, trials, variants = await asyncio.gather(
+         _gather_safely(*umls_f),
+         _gather_safely(*fda_f),
+         gene_bundle,
+         trials_task,
+         cbio_task,
      )

+     # 4 LLM summary
+     summarise_fn, _, engine = _llm_router(llm)
+     summary = await summarise_fn(" ".join(p["summary"] for p in papers)[:12000])

      return {
          "papers"          : papers,
          "umls"            : umls,
          "drug_safety"     : fda,
          "ai_summary"      : summary,
+         "llm_used"        : engine,
+         "genes"           : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
+         "mesh_defs"       : gene_dat["mesh"],
+         "gene_disease"    : gene_dat["ot_assoc"],
+         "clinical_trials" : trials,
+         "variants"        : variants or [],
      }

+ # ─────────────────────────────── follow-up QA ─────────────────────────────────
  async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
+     """Follow-up QA using chosen LLM."""
+     _, qa_fn, _ = _llm_router(llm)
+     return {"answer": await qa_fn(f"Q: {question}\nContext: {context}\nA:")}
+
+
+ # ─────────────────────────── internal util ───────────────────────────────────
+ def _safe_task(fn, *args):
+     """Helper to wrap callable → Task returning RuntimeError on exception."""
+     async def _wrapper():
+         try:
+             return await fn(*args)
+         except Exception as exc:
+             log.warning("background task %s failed: %s", fn.__name__, exc)
+             return RuntimeError(str(exc))
+     return asyncio.create_task(_wrapper())
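
Below is a minimal usage sketch of the updated orchestrator as a synchronous caller (for example, a Streamlit callback) might drive it. It is not part of the commit: the runner script, the asyncio.run wrapper, and the example query string are illustrative assumptions; only orchestrate_search, answer_ai_question, their keyword arguments, and the returned dict keys come from the code above.

  # usage sketch – illustrative only, not part of this commit
  import asyncio

  from mcp.orchestrator import orchestrate_search, answer_ai_question


  def main() -> None:
      # Run the async orchestrator from synchronous code.
      # The query string is a hypothetical example.
      results = asyncio.run(orchestrate_search("glioblastoma EGFR", llm="openai"))

      print(results["ai_summary"])
      print(f'{len(results["papers"])} papers, {len(results["variants"])} variants')

      # Follow-up QA: reuse the engine that produced the summary
      # and pass the summary back as context.
      answer = asyncio.run(
          answer_ai_question(
              "Which genes look most relevant?",
              context=results["ai_summary"],
              llm=results["llm_used"],
          )
      )
      print(answer["answer"])


  if __name__ == "__main__":
      main()

The two asyncio.run calls keep the sketch self-contained; a long-lived UI would more likely hold a single event loop and await both coroutines inside it.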