mgbam commited on
Commit
bc178da
·
verified ·
1 Parent(s): c3f5ed6

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +36 -31
mcp/orchestrator.py CHANGED
@@ -1,10 +1,9 @@
1
  # mcp/orchestrator.py
2
  import asyncio
3
- from typing import Any, Dict
4
  from mcp.arxiv import fetch_arxiv
5
  from mcp.pubmed import fetch_pubmed
6
  from mcp.nlp import extract_umls_concepts
7
- from mcp.umls import lookup_umls
8
  from mcp.umls_rel import fetch_relations
9
  from mcp.openfda import fetch_drug_safety
10
  from mcp.ncbi import search_gene, get_mesh_definition
@@ -17,63 +16,69 @@ from mcp.openai_utils import ai_summarize, ai_qa
17
  from mcp.gemini import gemini_summarize, gemini_qa
18
 
19
  def _get_llm(llm: str):
20
- return (gemini_summarize, gemini_qa) if llm.lower()=="gemini" else (ai_summarize, ai_qa)
21
 
22
- async def orchestrate_search(query: str, llm: str="openai") -> Dict[str,Any]:
23
- # 1) literature
24
  arxiv_t, pubmed_t = fetch_arxiv(query), fetch_pubmed(query)
25
  papers = []
26
  for res in await asyncio.gather(arxiv_t, pubmed_t, return_exceptions=True):
27
  if isinstance(res, list):
28
  papers.extend(res)
29
 
30
- # 2) UMLS concept linking
31
  blob = " ".join(p.get("summary","") for p in papers)
32
- umls = extract_umls_concepts(blob)
33
- rels = await asyncio.gather(*[fetch_relations(c["cui"]) for c in umls], return_exceptions=True)
34
 
35
- # 3) enrichment
 
 
 
 
 
 
36
  keys = [c["name"] for c in umls]
37
- fda_t = [fetch_drug_safety(k) for k in keys]
38
- genes_t = search_gene(keys[0]) if keys else asyncio.sleep(0, result=[])
39
- mesh_t = get_mesh_definition(keys[0]) if keys else asyncio.sleep(0, result="")
40
- dis_t = disease_to_genes(keys[0]) if keys else asyncio.sleep(0, result=[])
41
- trials_t = search_trials(query)
42
- ot_t = ot.fetch(keys[0]) if keys else asyncio.sleep(0, result=[])
43
- var_t = cbio.fetch_variants(keys[0]) if keys else asyncio.sleep(0, result=[])
44
 
45
- fda, genes, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
46
- asyncio.gather(*fda_t, return_exceptions=True),
47
- genes_t, mesh_t, dis_t, trials_t, ot_t, var_t,
 
48
  return_exceptions=False
49
  )
50
 
51
- # 4) AI summary
52
  summarize, _ = _get_llm(llm)
53
  try:
54
- summary = await summarize(blob)
55
- except:
56
- summary = "LLM unavailable."
57
 
58
  return {
59
  "papers": papers,
60
  "umls": umls,
61
  "umls_relations": rels,
62
  "drug_safety": fda,
63
- "genes": [genes],
64
  "mesh_defs": [mesh],
65
  "gene_disease": dis,
66
  "clinical_trials": trials,
67
  "ot_associations": ot_assoc,
68
  "variants": variants,
69
- "ai_summary": summary,
70
  "llm_used": llm.lower()
71
  }
72
 
73
- async def answer_ai_question(question: str, context: str="", llm: str="openai"):
74
- _, qa = _get_llm(llm)
75
  try:
76
- ans = await qa(question, context)
77
- except:
78
- ans = "LLM unavailable."
79
- return {"answer": ans}
 
1
  # mcp/orchestrator.py
2
  import asyncio
3
+ from typing import Dict, Any
4
  from mcp.arxiv import fetch_arxiv
5
  from mcp.pubmed import fetch_pubmed
6
  from mcp.nlp import extract_umls_concepts
 
7
  from mcp.umls_rel import fetch_relations
8
  from mcp.openfda import fetch_drug_safety
9
  from mcp.ncbi import search_gene, get_mesh_definition
 
16
  from mcp.gemini import gemini_summarize, gemini_qa
17
 
18
  def _get_llm(llm: str):
19
+ return (gemini_summarize, gemini_qa) if llm.lower() == "gemini" else (ai_summarize, ai_qa)
20
 
21
+ async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
22
+ # 1) Parallel literature pulls
23
  arxiv_t, pubmed_t = fetch_arxiv(query), fetch_pubmed(query)
24
  papers = []
25
  for res in await asyncio.gather(arxiv_t, pubmed_t, return_exceptions=True):
26
  if isinstance(res, list):
27
  papers.extend(res)
28
 
29
+ # 2) SpaCy→UMLS concept linking
30
  blob = " ".join(p.get("summary","") for p in papers)
31
+ umls = await extract_umls_concepts(blob)
 
32
 
33
+ # 3) Fetch UMLS relations in parallel
34
+ rels = await asyncio.gather(
35
+ *[fetch_relations(c["cui"]) for c in umls],
36
+ return_exceptions=True
37
+ )
38
+
39
+ # 4) Enrich: OpenFDA, NCBI, DisGeNET, Trials, OpenTargets, cBioPortal
40
  keys = [c["name"] for c in umls]
41
+ fda_tasks = [fetch_drug_safety(k) for k in keys]
42
+ gene_task = search_gene(keys[0]) if keys else asyncio.sleep(0, result=[])
43
+ mesh_task = get_mesh_definition(keys[0]) if keys else asyncio.sleep(0, result="")
44
+ dis_task = disease_to_genes(keys[0]) if keys else asyncio.sleep(0, result=[])
45
+ trials_task = search_trials(query)
46
+ ot_task = ot.fetch(keys[0]) if keys else asyncio.sleep(0, result=[])
47
+ cbio_task = cbio.fetch_variants(keys[0]) if keys else asyncio.sleep(0, result=[])
48
 
49
+ fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
50
+ asyncio.gather(*fda_tasks, return_exceptions=True),
51
+ gene_task, mesh_task, dis_task,
52
+ trials_task, ot_task, cbio_task,
53
  return_exceptions=False
54
  )
55
 
56
+ # 5) AI summary
57
  summarize, _ = _get_llm(llm)
58
  try:
59
+ ai_summary = await summarize(blob)
60
+ except Exception:
61
+ ai_summary = "LLM summary failed."
62
 
63
  return {
64
  "papers": papers,
65
  "umls": umls,
66
  "umls_relations": rels,
67
  "drug_safety": fda,
68
+ "genes": [gene],
69
  "mesh_defs": [mesh],
70
  "gene_disease": dis,
71
  "clinical_trials": trials,
72
  "ot_associations": ot_assoc,
73
  "variants": variants,
74
+ "ai_summary": ai_summary,
75
  "llm_used": llm.lower()
76
  }
77
 
78
+ async def answer_ai_question(question: str, context: str = "", llm: str = "openai"):
79
+ _, qa_fn = _get_llm(llm)
80
  try:
81
+ answer = await qa_fn(question, context)
82
+ except Exception:
83
+ answer = "LLM follow-up failed."
84
+ return {"answer": answer}