mgbam commited on
Commit
0bd4f6b
Β·
verified Β·
1 Parent(s): edc2450

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +120 -154
mcp/orchestrator.py CHANGED
@@ -1,161 +1,127 @@
1
- import asyncio
2
- import httpx
3
- from typing import Any, Dict, List, Literal, Union
4
-
5
- from mcp.pubmed import fetch_pubmed
6
- from mcp.arxiv import fetch_arxiv
7
- from mcp.umls import extract_umls_concepts, lookup_umls
8
- from mcp.openfda import fetch_drug_safety
9
- from mcp.ncbi import search_gene, get_mesh_definition
10
- from mcp.mygene import fetch_gene_info
11
- from mcp.ensembl import fetch_ensembl
12
- from mcp.opentargets import fetch_ot
13
- from mcp.clinicaltrials import search_trials
14
- from mcp.cbio import fetch_cbio
15
- from mcp.disgenet import disease_to_genes
16
- from mcp.gemini import gemini_summarize, gemini_qa
17
- from mcp.openai_utils import ai_summarize, ai_qa
18
-
19
-
20
- async def _safe_call(
21
- func: Any,
22
- *args,
23
- default: Any = None,
24
- **kwargs,
25
- ) -> Any:
26
- """
27
- Safely call an async function, returning a default on HTTP or other failures.
28
- """
29
- try:
30
- return await func(*args, **kwargs) # type: ignore
31
- except httpx.HTTPStatusError:
32
- return default
33
- except Exception:
34
- return default
35
-
36
-
37
- async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
38
- """
39
- Await a list of asyncio.Tasks and return their results in order.
40
- """
41
- return await asyncio.gather(*tasks)
42
-
43
-
44
- def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
45
- """
46
- Flatten a list of items where elements may be lists or single values,
47
- then deduplicate preserving insertion order.
48
- """
49
- flat: List[Any] = []
50
- for elem in items:
51
- if isinstance(elem, list):
52
- for x in elem:
53
- if x not in flat:
54
- flat.append(x)
55
- else:
56
- if elem is not None and elem not in flat:
57
- flat.append(elem)
58
- return flat
59
-
60
-
61
- async def orchestrate_search(
62
- query: str,
63
- llm: Literal['openai', 'gemini'] = 'openai',
64
- max_papers: int = 7,
65
- max_trials: int = 10,
66
- ) -> Dict[str, Any]:
67
- """
68
- Perform a comprehensive biomedical search pipeline with fault tolerance:
69
- - Extract UMLS concepts and fetch definitions
70
- - Literature (PubMed + arXiv)
71
- - Drug safety, gene & variant info, disease-gene mapping
72
- - Clinical trials, cBioPortal data
73
- - AI-driven summary
74
-
75
- Returns a dict with structured results ready for UI/graph building.
76
- """
77
- # 1) Extract concepts and perform UMLS lookups
78
- raw_concepts = await asyncio.to_thread(extract_umls_concepts, query)
79
- umls_tasks = [
80
- asyncio.create_task(
81
- _safe_call(
82
- lookup_umls,
83
- term,
84
- default={
85
- 'term': term,
86
- 'cui': None,
87
- 'name': None,
88
- 'definition': None,
89
- },
90
- )
91
- )
92
- for term in raw_concepts
93
- ]
94
-
95
- # 2) Launch parallel data-fetch tasks (excluding UMLS)
96
- tasks = {
97
- 'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
98
- 'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
99
- 'drug_safety': asyncio.create_task(_safe_call(fetch_drug_safety, query, default=[])),
100
- 'ncbi_gene': asyncio.create_task(_safe_call(search_gene, query, default=[])),
101
- 'mygene': asyncio.create_task(_safe_call(fetch_gene_info, query, default=[])),
102
- 'ensembl': asyncio.create_task(_safe_call(fetch_ensembl, query, default=[])),
103
- 'opentargets': asyncio.create_task(_safe_call(fetch_ot, query, default=[])),
104
- 'mesh': asyncio.create_task(_safe_call(get_mesh_definition, query, default="")),
105
- 'trials': asyncio.create_task(_safe_call(search_trials, query, default=[], max_studies=max_trials)),
106
- 'cbio': asyncio.create_task(_safe_call(fetch_cbio, query, default=[])),
107
- 'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
108
  }
109
 
110
- # 3) Await all tasks
111
- results = await _gather_tasks(list(tasks.values()))
112
- data = dict(zip(tasks.keys(), results))
113
- umls_results = await asyncio.gather(*umls_tasks)
114
-
115
- # 4) Consolidate gene sources
116
- gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
117
- genes = _flatten_unique(gene_sources)
118
 
119
- # 5) Merge literature
120
- papers = (data['pubmed'] or []) + (data['arxiv'] or [])
121
-
122
- # 6) AI-driven summary
123
- summaries = " ".join(p.get('summary', '') for p in papers)
124
- if llm == 'gemini':
125
- ai_summary = await gemini_summarize(summaries)
126
- llm_used = 'gemini'
127
- else:
128
- ai_summary = await ai_summarize(summaries)
129
- llm_used = 'openai'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  return {
132
- 'papers': papers,
133
- 'genes': genes,
134
- 'umls': umls_results,
135
- 'gene_disease': data['disgenet'] or [],
136
- 'mesh_defs': [data['mesh']] if data['mesh'] else [],
137
- 'drug_safety': data['drug_safety'] or [],
138
- 'clinical_trials': data['trials'] or [],
139
- 'variants': data['cbio'] or [],
140
- 'ai_summary': ai_summary,
141
- 'llm_used': llm_used,
142
  }
143
 
144
-
145
- async def answer_ai_question(
146
- question: str,
147
- context: str = "",
148
- llm: Literal['openai', 'gemini'] = 'openai',
149
- ) -> Dict[str, str]:
150
- """
151
- Answer a free-text question using the specified LLM, with fallback.
152
- Returns {'answer': text}.
153
- """
154
- try:
155
- if llm == 'gemini':
156
- answer = await gemini_qa(question, context)
157
- else:
158
- answer = await ai_qa(question, context)
159
- except Exception as e:
160
- answer = f"LLM unavailable or quota exceeded: {e}"
161
- return {'answer': answer}
 
1
+ """
2
+ MedGenesis – dual-LLM orchestrator
3
+ ----------------------------------
4
+ β€’ Accepts llm = "openai" | "gemini" (falls back to OpenAI)
5
+ β€’ Returns one unified dict the UI can rely on.
6
+ """
7
+ from __future__ import annotations
8
+ import asyncio, itertools, logging
9
+ from typing import Dict, Any, List, Tuple
10
+
11
+ from mcp.arxiv import fetch_arxiv
12
+ from mcp.pubmed import fetch_pubmed
13
+ from mcp.ncbi import search_gene, get_mesh_definition
14
+ from mcp.mygene import fetch_gene_info
15
+ from mcp.ensembl import fetch_ensembl
16
+ from mcp.opentargets import fetch_ot
17
+ from mcp.umls import lookup_umls
18
+ from mcp.openfda import fetch_drug_safety
19
+ from mcp.disgenet import disease_to_genes
20
+ from mcp.clinicaltrials import search_trials
21
+ from mcp.cbio import fetch_cbio
22
+ from mcp.openai_utils import ai_summarize, ai_qa
23
+ from mcp.gemini import gemini_summarize, gemini_qa
24
+
25
+ log = logging.getLogger(__name__)
26
+ _DEF = "openai" # default engine
27
+
28
+
29
+ # ─────────────────────────────────── helpers ───────────────────────────────────
30
+ def _llm_router(engine: str = _DEF) -> Tuple:
31
+ if engine.lower() == "gemini":
32
+ return gemini_summarize, gemini_qa, "gemini"
33
+ return ai_summarize, ai_qa, "openai"
34
+
35
+ async def _gather_safely(*aws, as_list: bool = True):
36
+ """await gather() that converts Exception β†’ RuntimeError placeholder"""
37
+ out = await asyncio.gather(*aws, return_exceptions=True)
38
+ if as_list:
39
+ # filter exceptions – keep structure but drop failures
40
+ return [x for x in out if not isinstance(x, Exception)]
41
+ return out
42
+
43
+ async def _gene_enrichment(keys: List[str]) -> Dict[str, Any]:
44
+ jobs = []
45
+ for k in keys:
46
+ jobs += [
47
+ search_gene(k), # basic gene info
48
+ get_mesh_definition(k), # MeSH definitions
49
+ fetch_gene_info(k), # MyGene
50
+ fetch_ensembl(k), # Ensembl x-refs
51
+ fetch_ot(k), # Open Targets associations
52
+ ]
53
+ res = await _gather_safely(*jobs, as_list=False)
54
+
55
+ # slice & compress five-way fan-out
56
+ combo = lambda idx: [r for i, r in enumerate(res) if i % 5 == idx and r]
57
+ return {
58
+ "ncbi" : combo(0),
59
+ "mesh" : combo(1),
60
+ "mygene" : combo(2),
61
+ "ensembl" : combo(3),
62
+ "ot_assoc" : combo(4),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  }
64
 
 
 
 
 
 
 
 
 
65
 
66
+ # ───────────────────────────────── orchestrator ────────────────────────────────
67
+ async def orchestrate_search(query: str, *, llm: str = _DEF) -> Dict[str, Any]:
68
+ """Main entry – returns dict for the Streamlit UI"""
69
+ # 1 Literature – run in parallel
70
+ arxiv_task = asyncio.create_task(fetch_arxiv(query))
71
+ pubmed_task = asyncio.create_task(fetch_pubmed(query))
72
+ papers_raw = await _gather_safely(arxiv_task, pubmed_task)
73
+ papers = list(itertools.chain.from_iterable(papers_raw))[:30] # keep ≀30
74
+
75
+ # 2 Keyword extraction (very light – only from abstracts)
76
+ kws = {w for p in papers for w in (p["summary"][:500].split()) if w.isalpha()}
77
+ kws = list(kws)[:10] # coarse, fast -> 10 seeds
78
+
79
+ # 3 Bio-enrichment fan-out
80
+ umls_f = [_safe_task(lookup_umls, k) for k in kws]
81
+ fda_f = [_safe_task(fetch_drug_safety, k) for k in kws]
82
+ gene_bundle = asyncio.create_task(_gene_enrichment(kws))
83
+ trials_task = asyncio.create_task(search_trials(query, max_studies=20))
84
+ cbio_task = asyncio.create_task(fetch_cbio(kws[0] if kws else ""))
85
+
86
+ umls, fda, gene_dat, trials, variants = await asyncio.gather(
87
+ _gather_safely(*umls_f),
88
+ _gather_safely(*fda_f),
89
+ gene_bundle,
90
+ trials_task,
91
+ cbio_task,
92
+ )
93
+
94
+ # 4 LLM summary
95
+ summarise_fn, _, engine = _llm_router(llm)
96
+ summary = await summarise_fn(" ".join(p["summary"] for p in papers)[:12000])
97
 
98
  return {
99
+ "papers" : papers,
100
+ "umls" : umls,
101
+ "drug_safety" : fda,
102
+ "ai_summary" : summary,
103
+ "llm_used" : engine,
104
+ "genes" : gene_dat["ncbi"] + gene_dat["ensembl"] + gene_dat["mygene"],
105
+ "mesh_defs" : gene_dat["mesh"],
106
+ "gene_disease" : gene_dat["ot_assoc"],
107
+ "clinical_trials" : trials,
108
+ "variants" : variants or [],
109
  }
110
 
111
+ # ─────────────────────────────── follow-up QA ─────────────────────────────────
112
+ async def answer_ai_question(question: str, *, context: str, llm: str = _DEF) -> Dict[str, str]:
113
+ """Follow-up QA using chosen LLM."""
114
+ _, qa_fn, _ = _llm_router(llm)
115
+ return {"answer": await qa_fn(f"Q: {question}\nContext: {context}\nA:")}
116
+
117
+
118
+ # ─────────────────────────── internal util ───────────────────────────────────
119
+ def _safe_task(fn, *args):
120
+ """Helper to wrap callable β†’ Task returning RuntimeError on exception."""
121
+ async def _wrapper():
122
+ try:
123
+ return await fn(*args)
124
+ except Exception as exc:
125
+ log.warning("background task %s failed: %s", fn.__name__, exc)
126
+ return RuntimeError(str(exc))
127
+ return asyncio.create_task(_wrapper())