mgbam committed
Commit e33dfeb · verified · 1 Parent(s): 078f31a

Update mcp/orchestrator.py

Files changed (1)
  1. mcp/orchestrator.py +79 -74
mcp/orchestrator.py CHANGED
@@ -1,117 +1,122 @@
  #!/usr/bin/env python3
- """MedGenesis – asynchronous orchestrator (v3)
-
- • Pulls literature (PubMed + arXiv)
- • Extracts keywords via spaCy → fans‑out to:
-     – MyGene.info (gene annotations)
-     – ClinicalTrials.gov v2 (trials) with v1 fallback
-     – UMLS / openFDA / DisGeNET / MeSH
- • Summarises with OpenAI or Gemini (user‑selectable)
- • Returns a single dict ready for Streamlit UI & pydantic schemas.
  """
  from __future__ import annotations

  import asyncio
  from typing import Any, Dict, List

- # ── Internal modules ────────────────────────────────────────────────
- from mcp.arxiv import fetch_arxiv
- from mcp.pubmed import fetch_pubmed
- from mcp.nlp import extract_keywords
- from mcp.umls import lookup_umls
- from mcp.openfda import fetch_drug_safety
- from mcp.ncbi import search_gene, get_mesh_definition
- from mcp.disgenet import disease_to_genes
- from mcp.mygene import fetch_gene_info
- from mcp.ctgov import search_trials  # v2 helper (v1 fallback inside)
- from mcp.openai_utils import ai_summarize, ai_qa
- from mcp.gemini import gemini_summarize, gemini_qa
-
- # -------------------------------------------------------------------
  # LLM router
- # -------------------------------------------------------------------
- _DEF_MODEL = "openai"

- def _get_llm(llm: str | None):
-     if llm and llm.lower() == "gemini":
          return gemini_summarize, gemini_qa, "gemini"
      return ai_summarize, ai_qa, "openai"

- # -------------------------------------------------------------------
- # Helper: NCBI / MeSH / DisGeNET enrichment
- # -------------------------------------------------------------------
  async def _enrich_keywords(keys: List[str]) -> Dict[str, Any]:
-     tasks = []
      for k in keys:
          tasks += [search_gene(k), get_mesh_definition(k), disease_to_genes(k)]

-     results = await asyncio.gather(*tasks, return_exceptions=True)

      genes, mesh_defs, disg = [], [], []
-     for idx, res in enumerate(results):
-         if isinstance(res, Exception):
              continue
          bucket = idx % 3
          if bucket == 0:
-             genes.extend(res)
          elif bucket == 1:
-             mesh_defs.append(res)
          else:
-             disg.extend(res)
      return {"genes": genes, "meshes": mesh_defs, "disgenet": disg}

- # -------------------------------------------------------------------
- # Public API
- # -------------------------------------------------------------------
- async def orchestrate_search(query: str, *, llm: str = _DEF_MODEL) -> Dict[str, Any]:
-     """Run full async pipeline and return merged result dict."""
      # 1) Literature --------------------------------------------------
-     arxiv_task = asyncio.create_task(fetch_arxiv(query, max_results=10))
-     pubmed_task = asyncio.create_task(fetch_pubmed(query, max_results=10))
-     papers = sum(await asyncio.gather(arxiv_task, pubmed_task), [])
-
-     # 2) Keyword extraction -----------------------------------------
-     corpus = " ".join(p["summary"] for p in papers)
      keywords = extract_keywords(corpus)[:8]

-     # 3) Enrichment fan‑out -----------------------------------------
-     umls_tasks = [lookup_umls(k) for k in keywords]
-     fda_tasks = [fetch_drug_safety(k) for k in keywords]
-
-     ncbi_task = asyncio.create_task(_enrich_keywords(keywords))
-     gene_task = asyncio.create_task(fetch_gene_info(query))
-     trials_task = asyncio.create_task(search_trials(query, max_studies=20))
-
-     umls, fda, ncbi_data, mygene, trials = await asyncio.gather(
-         asyncio.gather(*umls_tasks, return_exceptions=True),
-         asyncio.gather(*fda_tasks, return_exceptions=True),
-         ncbi_task,
-         gene_task,
-         trials_task,
      )

      # 4) LLM summary -------------------------------------------------
-     summarize_fn, _, engine = _get_llm(llm)
-     ai_summary = await summarize_fn(corpus)

-     # 5) Assemble result --------------------------------------------
      return {
          "papers" : papers,
          "umls" : umls,
          "drug_safety" : fda,
          "ai_summary" : ai_summary,
          "llm_used" : engine,
-
-         # genes & ontologies
-         "genes" : (ncbi_data["genes"] or []) + ([mygene] if mygene else []),
-         "mesh_defs" : ncbi_data["meshes"],
-         "gene_disease" : ncbi_data["disgenet"],
-
          # trials
          "clinical_trials": trials,
      }

-
- async def answer_ai_question(question: str, *, context: str, llm: str = _DEF_MODEL) -> Dict[str, str]:
-     summarize, qa_fn, engine = _get_llm(llm)
      return {"answer": await qa_fn(question, context)}
 
  #!/usr/bin/env python3
+ """MedGenesis – orchestrator (v4, resilient).
+
+ * Pulls PubMed + arXiv (async)
+ * spaCy keyword extraction → UMLS / openFDA / DisGeNET / MeSH fan‑out
+ * Adds MyGene.info, ClinicalTrials.gov (v2 → v1 fallback)
+ * Filters out failed enrichment calls (exceptions) so UI never crashes
+ * Summarises with OpenAI *or* Gemini; router returns `llm_used`
  """
  from __future__ import annotations

  import asyncio
  from typing import Any, Dict, List

+ # ── async fetchers ──────────────────────────────────────────────────
+ from mcp.arxiv import fetch_arxiv
+ from mcp.pubmed import fetch_pubmed
+ from mcp.nlp import extract_keywords
+ from mcp.umls import lookup_umls
+ from mcp.openfda import fetch_drug_safety
+ from mcp.ncbi import search_gene, get_mesh_definition
+ from mcp.disgenet import disease_to_genes
+ from mcp.mygene import fetch_gene_info
+ from mcp.ctgov import search_trials
+
+ # ── LLM helpers ────────────────────────────────────────────────────
+ from mcp.openai_utils import ai_summarize, ai_qa
+ from mcp.gemini import gemini_summarize, gemini_qa
+
+ # ------------------------------------------------------------------
  # LLM router
+ # ------------------------------------------------------------------
+ _DEF = "openai"

+ def _llm_router(name: str | None):
+     if name and name.lower() == "gemini":
          return gemini_summarize, gemini_qa, "gemini"
      return ai_summarize, ai_qa, "openai"

+ # ------------------------------------------------------------------
+ # Keyword enrichment bundle
+ # ------------------------------------------------------------------
  async def _enrich_keywords(keys: List[str]) -> Dict[str, Any]:
+     tasks: List[asyncio.Future] = []
      for k in keys:
          tasks += [search_gene(k), get_mesh_definition(k), disease_to_genes(k)]

+     res = await asyncio.gather(*tasks, return_exceptions=True)

      genes, mesh_defs, disg = [], [], []
+     for idx, r in enumerate(res):
+         if isinstance(r, Exception):
              continue
          bucket = idx % 3
          if bucket == 0:
+             genes.extend(r)
          elif bucket == 1:
+             mesh_defs.append(r)
          else:
+             disg.extend(r)
      return {"genes": genes, "meshes": mesh_defs, "disgenet": disg}

+ # ------------------------------------------------------------------
+ # Main orchestrator
+ # ------------------------------------------------------------------
+ async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
+     """Run entire async pipeline; never raises."""
      # 1) Literature --------------------------------------------------
+     arxiv_f = asyncio.create_task(fetch_arxiv(query, max_results=10))
+     pubmed_f = asyncio.create_task(fetch_pubmed(query, max_results=10))
+     papers = []
+     for fut in await asyncio.gather(arxiv_f, pubmed_f, return_exceptions=True):
+         if not isinstance(fut, Exception):
+             papers.extend(fut)
+
+     # 2) Keywords ----------------------------------------------------
+     corpus = " ".join(p.get("summary", "") for p in papers)
      keywords = extract_keywords(corpus)[:8]

+     # 3) Fan‑out enrichment -----------------------------------------
+     umls_f = [lookup_umls(k) for k in keywords]
+     fda_f = [fetch_drug_safety(k) for k in keywords]
+     ncbi_f = asyncio.create_task(_enrich_keywords(keywords))
+     mygene_f = asyncio.create_task(fetch_gene_info(query))
+     trials_f = asyncio.create_task(search_trials(query, max_studies=20))
+
+     umls, fda, ncbi, mygene, trials = await asyncio.gather(
+         asyncio.gather(*umls_f, return_exceptions=True),
+         asyncio.gather(*fda_f, return_exceptions=True),
+         ncbi_f,
+         mygene_f,
+         trials_f,
      )

+     # ── filter exception objects -----------------------------------
+     umls = [u for u in umls if isinstance(u, dict)]
+     fda = [d for d in fda if isinstance(d, (dict, list))]
+
      # 4) LLM summary -------------------------------------------------
+     summarize, _, engine = _llm_router(llm)
+     ai_summary = await summarize(corpus) if corpus else ""

+     # 5) Assemble payload -------------------------------------------
      return {
          "papers" : papers,
          "umls" : umls,
          "drug_safety" : fda,
          "ai_summary" : ai_summary,
          "llm_used" : engine,
+         # gene context
+         "genes" : (ncbi["genes"] or []) + ([mygene] if mygene else []),
+         "mesh_defs" : ncbi["meshes"],
+         "gene_disease" : ncbi["disgenet"],
          # trials
          "clinical_trials": trials,
      }

+ # ------------------------------------------------------------------
+ async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
+     """Follow‑up QA using chosen LLM."""
+     _, qa_fn, _ = _llm_router(llm)
      return {"answer": await qa_fn(question, context)}
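
The core behavioural change in this commit is defensive gathering: every fan-out call in the new version (the `+` lines above) runs under asyncio.gather(..., return_exceptions=True), and exception objects are filtered out before the payload is assembled, so one failing upstream API no longer crashes the UI. The following is a minimal, self-contained sketch of that pattern; flaky_fetch is a hypothetical stand-in for any of the real enrichment helpers, not code from this repository.

# Sketch of the gather-and-filter ("never raises") pattern used above.
# flaky_fetch is a placeholder coroutine, not part of mcp/.
import asyncio
from typing import Any, Dict, List

async def flaky_fetch(term: str) -> Dict[str, Any]:
    if term == "bad-term":
        raise RuntimeError("upstream API error")
    return {"term": term}

async def fetch_all(terms: List[str]) -> List[Dict[str, Any]]:
    results = await asyncio.gather(
        *(flaky_fetch(t) for t in terms),
        return_exceptions=True,  # exceptions come back as objects instead of raising
    )
    # Keep only successful dict results, mirroring the isinstance() filters above.
    return [r for r in results if isinstance(r, dict)]

print(asyncio.run(fetch_all(["BRAF", "bad-term", "EGFR"])))
# [{'term': 'BRAF'}, {'term': 'EGFR'}]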
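
For orientation, a hedged sketch of how the updated public API might be driven from a script or a Streamlit callback. The query string and the asyncio.run entry point are illustrative assumptions; only orchestrate_search and answer_ai_question come from this commit, and the module is assumed to be importable with the relevant API keys configured.

# Hypothetical driver script, not part of this commit.
import asyncio

from mcp.orchestrator import answer_ai_question, orchestrate_search

async def main() -> None:
    # Full pipeline: literature, enrichment fan-out, trials, LLM summary.
    res = await orchestrate_search("glioblastoma", llm="openai")
    print(res["llm_used"], "-", len(res["papers"]), "papers")

    # Follow-up question grounded in the generated summary.
    qa = await answer_ai_question(
        "Which genes look most relevant?",
        context=res["ai_summary"],
        llm="openai",
    )
    print(qa["answer"])

if __name__ == "__main__":
    asyncio.run(main())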