Spaces:
Starting on CPU Upgrade

mgbam committed on
Commit
0fb7617
·
verified ·
1 Parent(s): 70ede12

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +64 -24
mcp/orchestrator.py CHANGED
@@ -9,7 +9,7 @@ MedGenesis – dual-LLM orchestrator (v5)
9
 
10
  from __future__ import annotations
11
  import asyncio, itertools, logging
12
- from typing import Dict, Any, List
13
 
14
  from mcp.arxiv import fetch_arxiv
15
  from mcp.pubmed import fetch_pubmed
@@ -29,15 +29,19 @@ log = logging.getLogger(__name__)
29
  _DEFAULT_LLM = "openai"
30
 
31
 
32
- def _llm_router(engine: str = _DEFAULT_LLM):
 
33
  if engine.lower() == "gemini":
34
  return gemini_summarize, gemini_qa, "gemini"
35
  return ai_summarize, ai_qa, "openai"
36
 
37
 
38
  async def _safe_gather(*tasks, return_exceptions: bool = False):
 
 
 
39
  results = await asyncio.gather(*tasks, return_exceptions=True)
40
- cleaned = []
41
  for r in results:
42
  if isinstance(r, Exception):
43
  log.warning("Task failed: %s", r)
@@ -49,7 +53,16 @@ async def _safe_gather(*tasks, return_exceptions: bool = False):
49
 
50
 
51
  async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
52
- jobs = []
 
 
 
 
 
 
 
 
 
53
  for k in keys:
54
  jobs.extend([
55
  asyncio.create_task(search_gene(k)),
@@ -58,9 +71,11 @@ async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
58
  asyncio.create_task(fetch_ensembl(k)),
59
  asyncio.create_task(fetch_ot(k)),
60
  ])
61
- res = await _safe_gather(*jobs, return_exceptions=True)
62
- # split into buckets of 5
63
- def bucket(i): return [x for idx, x in enumerate(res) if idx % 5 == i and not isinstance(x, Exception)]
 
 
64
  return {
65
  "ncbi": bucket(0),
66
  "mesh": bucket(1),
@@ -71,22 +86,35 @@ async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
71
 
72
 
73
  async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
74
  # 1) Literature
75
- pm_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
76
- ar_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
77
- papers_raw = await _safe_gather(pm_t, ar_t)
78
  papers = list(itertools.chain.from_iterable(papers_raw))[:30]
79
 
80
- # 2) Seeds
81
  seeds = {
82
- w for p in papers for w in p.get("summary", "")[:500].split() if w.isalpha()
 
 
 
83
  }
84
  seeds = list(seeds)[:10]
85
 
86
- # 3) Fan-out
87
  umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
88
  fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
89
- gene_t = asyncio.create_task(_gene_enrichment(seeds))
90
  trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
91
  cbio_t = asyncio.create_task(
92
  fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
@@ -95,12 +123,12 @@ async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, A
95
  umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
96
  _safe_gather(*umls_tasks, return_exceptions=True),
97
  _safe_gather(*fda_tasks, return_exceptions=True),
98
- gene_t,
99
  trials_t,
100
  cbio_t,
101
  )
102
 
103
- # 4) Genes
104
  genes = {
105
  g["symbol"]
106
  for src in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
@@ -108,22 +136,31 @@ async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, A
108
  }
109
  genes = list(genes)
110
 
111
- # 5) Dedupe variants by coords
112
- seen = set(); unique_vars = []
 
113
  for v in variants or []:
114
- key = (v.get("chromosome"), v.get("startPosition"), v.get("referenceAllele"), v.get("variantAllele"))
 
 
 
 
 
115
  if key not in seen:
116
- seen.add(key); unique_vars.append(v)
 
117
 
118
- # 6) LLM summary
119
- sum_fn, _, engine_used = _llm_router(llm)
120
  combined = " ".join(p.get("summary", "") for p in papers)
121
- ai_summary = await sum_fn(combined[:12000])
122
 
123
  return {
124
  "papers": papers,
125
  "umls": [u for u in umls_list if not isinstance(u, Exception)],
126
- "drug_safety": list(itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))),
 
 
127
  "clinical_trials": trials or [],
128
  "variants": unique_vars,
129
  "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
@@ -135,6 +172,9 @@ async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, A
135
 
136
 
137
  async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
 
 
 
138
  _, qa_fn, _ = _llm_router(llm)
139
  prompt = f"Q: {question}\nContext: {context}\nA:"
140
  try:
 
9
 
10
  from __future__ import annotations
11
  import asyncio, itertools, logging
12
+ from typing import Dict, Any, List, Tuple
13
 
14
  from mcp.arxiv import fetch_arxiv
15
  from mcp.pubmed import fetch_pubmed
 
29
  _DEFAULT_LLM = "openai"
30
 
31
 
32
+ def _llm_router(engine: str = _DEFAULT_LLM) -> Tuple:
33
+ """Choose summarization and QA functions based on engine name."""
34
  if engine.lower() == "gemini":
35
  return gemini_summarize, gemini_qa, "gemini"
36
  return ai_summarize, ai_qa, "openai"
37
 
38
 
39
  async def _safe_gather(*tasks, return_exceptions: bool = False):
40
+ """
41
+ Await multiple coroutines, log any exceptions, and optionally return them.
42
+ """
43
  results = await asyncio.gather(*tasks, return_exceptions=True)
44
+ cleaned: List[Any] = []
45
  for r in results:
46
  if isinstance(r, Exception):
47
  log.warning("Task failed: %s", r)
 
53
 
54
 
55
  async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
56
+ """
57
+ Fan-out gene-related endpoints for each seed keyword:
58
+ - NCBI gene lookup
59
+ - MeSH definition
60
+ - MyGene.info
61
+ - Ensembl cross-refs
62
+ - OpenTargets associations
63
+ Returns a dict of results.
64
+ """
65
+ jobs: List[asyncio.Task] = []
66
  for k in keys:
67
  jobs.extend([
68
  asyncio.create_task(search_gene(k)),
 
71
  asyncio.create_task(fetch_ensembl(k)),
72
  asyncio.create_task(fetch_ot(k)),
73
  ])
74
+ results = await _safe_gather(*jobs, return_exceptions=True)
75
+
76
+ def bucket(idx: int) -> List[Any]:
77
+ return [res for i, res in enumerate(results) if i % 5 == idx and not isinstance(res, Exception)]
78
+
79
  return {
80
  "ncbi": bucket(0),
81
  "mesh": bucket(1),
 
86
 
87
 
88
  async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
89
+ """
90
+ Main entry point. Performs:
91
+ 1. Literature fetch (PubMed + arXiv)
92
+ 2. Keyword seed extraction
93
+ 3. Bio-enrichment (UMLS, OpenFDA, gene services)
94
+ 4. Clinical trials lookup
95
+ 5. cBioPortal variants
96
+ 6. AI LLM summary
97
+ Returns a unified dict for the UI.
98
+ """
99
  # 1) Literature
100
+ pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
101
+ arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
102
+ papers_raw = await _safe_gather(pubmed_t, arxiv_t)
103
  papers = list(itertools.chain.from_iterable(papers_raw))[:30]
104
 
105
+ # 2) Seed keywords
106
  seeds = {
107
+ w.strip()
108
+ for p in papers
109
+ for w in p.get("summary", "")[:500].split()
110
+ if w.isalpha()
111
  }
112
  seeds = list(seeds)[:10]
113
 
114
+ # 3) Bio-enrichment fan-out
115
  umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
116
  fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
117
+ gene_task = asyncio.create_task(_gene_enrichment(seeds))
118
  trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
119
  cbio_t = asyncio.create_task(
120
  fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
 
123
  umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
124
  _safe_gather(*umls_tasks, return_exceptions=True),
125
  _safe_gather(*fda_tasks, return_exceptions=True),
126
+ gene_task,
127
  trials_t,
128
  cbio_t,
129
  )
130
 
131
+ # 4) Deduplicate gene symbols from enrichment
132
  genes = {
133
  g["symbol"]
134
  for src in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
 
136
  }
137
  genes = list(genes)
138
 
139
+ # 5) Deduplicate variants by genomic coordinates
140
+ seen: set = set()
141
+ unique_vars: List[dict] = []
142
  for v in variants or []:
143
+ key = (
144
+ v.get("chromosome"),
145
+ v.get("startPosition"),
146
+ v.get("referenceAllele"),
147
+ v.get("variantAllele"),
148
+ )
149
  if key not in seen:
150
+ seen.add(key)
151
+ unique_vars.append(v)
152
 
153
+ # 6) LLM-driven summary
154
+ summarize_fn, _, engine_used = _llm_router(llm)
155
  combined = " ".join(p.get("summary", "") for p in papers)
156
+ ai_summary = await summarize_fn(combined[:12000])
157
 
158
  return {
159
  "papers": papers,
160
  "umls": [u for u in umls_list if not isinstance(u, Exception)],
161
+ "drug_safety": list(
162
+ itertools.chain.from_iterable(dfa for dfa in fda_list if isinstance(dfa, list))
163
+ ),
164
  "clinical_trials": trials or [],
165
  "variants": unique_vars,
166
  "genes": gene_data["ncbi"] + gene_data["ensembl"] + gene_data["mygene"],
 
172
 
173
 
174
  async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
175
+ """
176
+ Follow-up QA: uses the designated QA function from the LLM router.
177
+ """
178
  _, qa_fn, _ = _llm_router(llm)
179
  prompt = f"Q: {question}\nContext: {context}\nA:"
180
  try: