mgbam commited on
Commit
bd2d9e0
·
verified ·
1 Parent(s): 7c64f21

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +70 -38
mcp/orchestrator.py CHANGED
@@ -1,47 +1,79 @@
1
  # mcp/orchestrator.py
2
  import asyncio
3
- from mcp.mygene import mygene
4
- from mcp.opentargets import ot
5
- from mcp.cbio import cbio
6
- # import pubmed, umls, clinicaltrials, etc …
7
- from typing import Dict, Any
8
-
9
- async def orchestrate_search(query: str, *, llm: str="openai") -> Dict[str,Any]:
10
- # 1) fetch papers + abstracts
11
- papers_task = asyncio.create_task(fetch_papers(query))
12
- # 2) pull UMLS concepts
13
- from mcp.nlp import extract_keywords
14
- kws = extract_keywords(query)[:5]
15
- umls_tasks = [lookup_umls(k) for k in kws]
16
- # 3) fetch gene info + associations
17
- gene_task = asyncio.create_task(mygene.fetch(query))
18
- ot_task = asyncio.create_task(ot.fetch(query))
19
- # 4) fetch variants
20
- cbio_task = asyncio.create_task(cbio.fetch_variants(query))
21
- # 5) clinical trials
22
- trials_task = asyncio.create_task(search_trials(query))
23
-
24
- # wait all
25
- papers = await papers_task
26
- umls = await asyncio.gather(*umls_tasks, return_exceptions=True)
27
- gene, assoc, vars_, trials = await asyncio.gather(
28
- gene_task, ot_task, cbio_task, trials_task, return_exceptions=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  )
30
 
31
- # 6) call your chosen LLM
32
- from mcp.ai import ai_summarize, gemini_summarize
33
- if llm=="openai":
34
- summary = await ai_summarize("\n\n".join(p["summary"] for p in papers))
35
- else:
36
- summary = await gemini_summarize("\n\n".join(p["summary"] for p in papers))
37
 
38
  return {
39
  "papers": papers,
40
- "umls": [u for u in umls if not isinstance(u, Exception)],
41
- "gene": gene if not isinstance(gene, Exception) else {},
42
- "associations": assoc if not isinstance(assoc, Exception) else [],
43
- "variants": vars_ if not isinstance(vars_, Exception) else [],
44
- "trials": trials if not isinstance(trials, Exception) else [],
 
 
 
 
45
  "ai_summary": summary,
46
- "llm_used": llm
47
  }
 
 
 
 
 
 
 
 
 
1
  # mcp/orchestrator.py
2
  import asyncio
3
+ from typing import Any, Dict
4
+ from mcp.arxiv import fetch_arxiv
5
+ from mcp.pubmed import fetch_pubmed
6
+ from mcp.nlp import extract_umls_concepts
7
+ from mcp.umls import lookup_umls
8
+ from mcp.umls_rel import fetch_relations
9
+ from mcp.openfda import fetch_drug_safety
10
+ from mcp.ncbi import search_gene, get_mesh_definition
11
+ from mcp.disgenet import disease_to_genes
12
+ from mcp.clinicaltrials import search_trials
13
+ from mcp.mygene import mygene
14
+ from mcp.opentargets import ot
15
+ from mcp.cbio import cbio
16
+ from mcp.openai_utils import ai_summarize, ai_qa
17
+ from mcp.gemini import gemini_summarize, gemini_qa
18
+
19
+ def _get_llm(llm: str):
20
+ return (gemini_summarize, gemini_qa) if llm.lower()=="gemini" else (ai_summarize, ai_qa)
21
+
22
+ async def orchestrate_search(query: str, llm: str="openai") -> Dict[str,Any]:
23
+ # 1) literature
24
+ arxiv_t, pubmed_t = fetch_arxiv(query), fetch_pubmed(query)
25
+ papers = []
26
+ for res in await asyncio.gather(arxiv_t, pubmed_t, return_exceptions=True):
27
+ if isinstance(res, list):
28
+ papers.extend(res)
29
+
30
+ # 2) UMLS concept linking
31
+ blob = " ".join(p.get("summary","") for p in papers)
32
+ umls = extract_umls_concepts(blob)
33
+ rels = await asyncio.gather(*[fetch_relations(c["cui"]) for c in umls], return_exceptions=True)
34
+
35
+ # 3) enrichment
36
+ keys = [c["name"] for c in umls]
37
+ fda_t = [fetch_drug_safety(k) for k in keys]
38
+ genes_t = search_gene(keys[0]) if keys else asyncio.sleep(0, result=[])
39
+ mesh_t = get_mesh_definition(keys[0]) if keys else asyncio.sleep(0, result="")
40
+ dis_t = disease_to_genes(keys[0]) if keys else asyncio.sleep(0, result=[])
41
+ trials_t = search_trials(query)
42
+ ot_t = ot.fetch(keys[0]) if keys else asyncio.sleep(0, result=[])
43
+ var_t = cbio.fetch_variants(keys[0]) if keys else asyncio.sleep(0, result=[])
44
+
45
+ fda, genes, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
46
+ asyncio.gather(*fda_t, return_exceptions=True),
47
+ genes_t, mesh_t, dis_t, trials_t, ot_t, var_t,
48
+ return_exceptions=False
49
  )
50
 
51
+ # 4) AI summary
52
+ summarize, _ = _get_llm(llm)
53
+ try:
54
+ summary = await summarize(blob)
55
+ except:
56
+ summary = "LLM unavailable."
57
 
58
  return {
59
  "papers": papers,
60
+ "umls": umls,
61
+ "umls_relations": rels,
62
+ "drug_safety": fda,
63
+ "genes": [genes],
64
+ "mesh_defs": [mesh],
65
+ "gene_disease": dis,
66
+ "clinical_trials": trials,
67
+ "ot_associations": ot_assoc,
68
+ "variants": variants,
69
  "ai_summary": summary,
70
+ "llm_used": llm.lower()
71
  }
72
+
73
+ async def answer_ai_question(question: str, context: str="", llm: str="openai"):
74
+ _, qa = _get_llm(llm)
75
+ try:
76
+ ans = await qa(question, context)
77
+ except:
78
+ ans = "LLM unavailable."
79
+ return {"answer": ans}