mgbam committed on
Commit
b15fc81
·
verified ·
1 Parent(s): a130367

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +53 -89
mcp/orchestrator.py CHANGED
@@ -1,13 +1,10 @@
1
  #!/usr/bin/env python3
2
- # mcp/orchestrator.py
3
-
4
  """
5
- MedGenesis – dual-LLM orchestrator (v4)
6
  ---------------------------------------
7
- Accepts llm="openai" | "gemini" (defaults to OpenAI)
8
- Safely runs all data-source calls in parallel
9
- Uses pytrials for ClinicalTrials.gov and pybioportal for cBioPortal
10
- • Returns one dict that the Streamlit UI can rely on
11
  """
12
 
13
  from __future__ import annotations
@@ -33,90 +30,95 @@ _DEFAULT_LLM = "openai"
33
 
34
 
35
def _llm_router(engine: str = _DEFAULT_LLM):
    """Return the (summarize_fn, qa_fn, engine_name) triple for *engine*."""
    if engine.lower() != "gemini":
        # Default path: any value other than "gemini" routes to OpenAI.
        return ai_summarize, ai_qa, "openai"
    return gemini_summarize, gemini_qa, "gemini"
40
 
41
 
42
  async def _safe_gather(*tasks, return_exceptions: bool = False):
43
- """
44
- Wrapper around asyncio.gather that logs failures
45
- and optionally returns exceptions as results.
46
- """
47
  results = await asyncio.gather(*tasks, return_exceptions=True)
48
  cleaned = []
49
- for idx, res in enumerate(results):
50
- if isinstance(res, Exception):
51
- log.warning("Task %d failed: %s", idx, res)
52
  if return_exceptions:
53
- cleaned.append(res)
54
  else:
55
- cleaned.append(res)
56
  return cleaned
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
60
- """
61
- Main entry point for MedGenesis UI.
62
- Returns a dict with:
63
- - papers, umls, drug_safety, clinical_trials, variants
64
- - genes, mesh_defs, gene_disease
65
- - ai_summary, llm_used
66
- """
67
- # 1) Literature (PubMed + arXiv in parallel)
68
- pubmed_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
69
- arxiv_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
70
- papers_raw = await _safe_gather(pubmed_t, arxiv_t)
71
  papers = list(itertools.chain.from_iterable(papers_raw))[:30]
72
 
73
- # 2) Keyword seeds from abstracts (first 500 chars, split on whitespace)
74
  seeds = {
75
- w.strip()
76
- for p in papers
77
- for w in p.get("summary", "")[:500].split()
78
- if w.isalpha()
79
  }
80
  seeds = list(seeds)[:10]
81
 
82
- # 3) Fan-out all bio-enrichment tasks safely
83
  umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
84
  fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
85
- gene_enrich_t = asyncio.create_task(_gene_enrichment(seeds))
86
- trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
87
- cbio_t = asyncio.create_task(
88
  fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
89
  )
90
 
91
  umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
92
  _safe_gather(*umls_tasks, return_exceptions=True),
93
  _safe_gather(*fda_tasks, return_exceptions=True),
94
- gene_enrich_t,
95
  trials_t,
96
  cbio_t,
97
  )
98
 
99
- # 4) Deduplicate and flatten genes
100
  genes = {
101
  g["symbol"]
102
- for source in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
103
- for g in source if isinstance(g, dict) and g.get("symbol")
104
  }
105
  genes = list(genes)
106
 
107
- # 5) Dedupe variants by (chrom, pos, ref, alt) if returned as dicts
108
- seen = set()
109
- unique_vars: List[dict] = []
110
- for var in variants or []:
111
- key = (var.get("chromosome"), var.get("startPosition"), var.get("referenceAllele"), var.get("variantAllele"))
112
  if key not in seen:
113
- seen.add(key)
114
- unique_vars.append(var)
115
 
116
  # 6) LLM summary
117
- summarize_fn, _, engine_used = _llm_router(llm)
118
- long_text = " ".join(p.get("summary", "") for p in papers)
119
- ai_summary = await summarize_fn(long_text[:12000])
120
 
121
  return {
122
  "papers": papers,
@@ -132,45 +134,7 @@ async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, A
132
  }
133
 
134
 
135
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """
    Fan-out gene-related tasks for each seed key:
      - NCBI gene lookup
      - MeSH definition
      - MyGene.info
      - Ensembl xrefs
      - OpenTargets associations
    Returns a dict of lists, one list per source.
    """
    fetchers = (search_gene, get_mesh_definition, fetch_gene_info, fetch_ensembl, fetch_ot)
    jobs = [asyncio.create_task(fn(key)) for key in keys for fn in fetchers]

    results = await _safe_gather(*jobs, return_exceptions=True)

    width = len(fetchers)

    def _column(col: int):
        # Each key contributed `width` consecutive tasks; pick out column
        # `col` and silently drop the entries that failed.
        return [
            r
            for i, r in enumerate(results)
            if i % width == col and not isinstance(r, Exception)
        ]

    names = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")
    return {name: _column(i) for i, name in enumerate(names)}
168
-
169
-
170
  async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
171
- """
172
- Follow-up QA: wraps the chosen LLM’s QA function.
173
- """
174
  _, qa_fn, _ = _llm_router(llm)
175
  prompt = f"Q: {question}\nContext: {context}\nA:"
176
  try:
 
1
  #!/usr/bin/env python3
 
 
2
  """
3
+ MedGenesis – dual-LLM orchestrator (v5)
4
  ---------------------------------------
5
+ No external 'pytrials' dependency.
6
+ Uses direct HTTP for clinical trials.
7
+ Clean async fan-out, dual-LLM support.
 
8
  """
9
 
10
  from __future__ import annotations
 
30
 
31
 
32
def _llm_router(engine: str = _DEFAULT_LLM):
    """Map an engine name to its (summarize_fn, qa_fn, canonical_name) triple."""
    normalized = engine.lower()
    if normalized == "gemini":
        return gemini_summarize, gemini_qa, "gemini"
    # Anything other than "gemini" (including unknown values) uses OpenAI.
    return ai_summarize, ai_qa, "openai"
36
 
37
 
38
  async def _safe_gather(*tasks, return_exceptions: bool = False):
 
 
 
 
39
  results = await asyncio.gather(*tasks, return_exceptions=True)
40
  cleaned = []
41
+ for r in results:
42
+ if isinstance(r, Exception):
43
+ log.warning("Task failed: %s", r)
44
  if return_exceptions:
45
+ cleaned.append(r)
46
  else:
47
+ cleaned.append(r)
48
  return cleaned
49
 
50
 
51
async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
    """Run the five gene lookups for every seed key and group results by source."""
    jobs = []
    for key in keys:
        jobs.append(asyncio.create_task(search_gene(key)))
        jobs.append(asyncio.create_task(get_mesh_definition(key)))
        jobs.append(asyncio.create_task(fetch_gene_info(key)))
        jobs.append(asyncio.create_task(fetch_ensembl(key)))
        jobs.append(asyncio.create_task(fetch_ot(key)))

    res = await _safe_gather(*jobs, return_exceptions=True)

    # Tasks were scheduled in a fixed 5-wide cycle per key, so result index
    # modulo 5 identifies the source; failed lookups are dropped.
    order = ("ncbi", "mesh", "mygene", "ensembl", "ot_assoc")
    buckets: Dict[str, Any] = {name: [] for name in order}
    for i, item in enumerate(res):
        if not isinstance(item, Exception):
            buckets[order[i % 5]].append(item)
    return buckets
71
+
72
+
73
  async def orchestrate_search(query: str, llm: str = _DEFAULT_LLM) -> Dict[str, Any]:
74
+ # 1) Literature
75
+ pm_t = asyncio.create_task(fetch_pubmed(query, max_results=7))
76
+ ar_t = asyncio.create_task(fetch_arxiv(query, max_results=7))
77
+ papers_raw = await _safe_gather(pm_t, ar_t)
 
 
 
 
 
 
 
78
  papers = list(itertools.chain.from_iterable(papers_raw))[:30]
79
 
80
+ # 2) Seeds
81
  seeds = {
82
+ w for p in papers for w in p.get("summary", "")[:500].split() if w.isalpha()
 
 
 
83
  }
84
  seeds = list(seeds)[:10]
85
 
86
+ # 3) Fan-out
87
  umls_tasks = [asyncio.create_task(lookup_umls(k)) for k in seeds]
88
  fda_tasks = [asyncio.create_task(fetch_drug_safety(k)) for k in seeds]
89
+ gene_t = asyncio.create_task(_gene_enrichment(seeds))
90
+ trials_t = asyncio.create_task(fetch_clinical_trials(query, max_studies=10))
91
+ cbio_t = asyncio.create_task(
92
  fetch_cbio_variants(seeds[0]) if seeds else asyncio.sleep(0, result=[])
93
  )
94
 
95
  umls_list, fda_list, gene_data, trials, variants = await asyncio.gather(
96
  _safe_gather(*umls_tasks, return_exceptions=True),
97
  _safe_gather(*fda_tasks, return_exceptions=True),
98
+ gene_t,
99
  trials_t,
100
  cbio_t,
101
  )
102
 
103
+ # 4) Genes
104
  genes = {
105
  g["symbol"]
106
+ for src in (gene_data["ncbi"], gene_data["mygene"], gene_data["ensembl"], gene_data["ot_assoc"])
107
+ for g in src if isinstance(g, dict) and g.get("symbol")
108
  }
109
  genes = list(genes)
110
 
111
+ # 5) Dedupe variants by coords
112
+ seen = set(); unique_vars = []
113
+ for v in variants or []:
114
+ key = (v.get("chromosome"), v.get("startPosition"), v.get("referenceAllele"), v.get("variantAllele"))
 
115
  if key not in seen:
116
+ seen.add(key); unique_vars.append(v)
 
117
 
118
  # 6) LLM summary
119
+ sum_fn, _, engine_used = _llm_router(llm)
120
+ combined = " ".join(p.get("summary", "") for p in papers)
121
+ ai_summary = await sum_fn(combined[:12000])
122
 
123
  return {
124
  "papers": papers,
 
134
  }
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  async def answer_ai_question(question: str, context: str, llm: str = _DEFAULT_LLM) -> Dict[str, str]:
 
 
 
138
  _, qa_fn, _ = _llm_router(llm)
139
  prompt = f"Q: {question}\nContext: {context}\nA:"
140
  try: