mgbam commited on
Commit
afa570e
·
verified ·
1 Parent(s): b3cf05a

Update mcp/orchestrator.py

Browse files
Files changed (1) hide show
  1. mcp/orchestrator.py +8 -12
mcp/orchestrator.py CHANGED
@@ -16,6 +16,7 @@ from mcp.disgenet import disease_to_genes
16
  from mcp.gemini import gemini_summarize, gemini_qa
17
  from mcp.openai_utils import ai_summarize, ai_qa
18
 
 
19
  async def _safe_call(
20
  func: Any,
21
  *args,
@@ -26,12 +27,13 @@ async def _safe_call(
26
  Safely call an async function, returning a default on HTTP or other failures.
27
  """
28
  try:
29
- return await func(*args, **kwargs)
30
  except httpx.HTTPStatusError:
31
  return default
32
  except Exception:
33
  return default
34
 
 
35
  async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
36
  """
37
  Await a list of asyncio.Tasks and return their results in order.
@@ -45,19 +47,17 @@ def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
45
  then deduplicate preserving order.
46
  """
47
  flat: List[Any] = []
48
- seen = set()
49
  for elem in items:
50
  if isinstance(elem, list):
51
  for x in elem:
52
- if x not in seen:
53
- seen.add(x)
54
  flat.append(x)
55
- elif elem is not None:
56
- if elem not in seen:
57
- seen.add(elem)
58
  flat.append(elem)
59
  return flat
60
 
 
61
  async def orchestrate_search(
62
  query: str,
63
  llm: Literal['openai', 'gemini'] = 'openai',
@@ -75,7 +75,6 @@ async def orchestrate_search(
75
  Individual fetch functions that fail with an HTTP error will return an empty default,
76
  ensuring the pipeline always completes.
77
  """
78
- # Launch parallel tasks with safe wrapper for potential HTTP errors
79
  tasks = {
80
  'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
81
  'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
@@ -93,18 +92,14 @@ async def orchestrate_search(
93
  'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
94
  }
95
 
96
- # Await all tasks
97
  results = await _gather_tasks(list(tasks.values()))
98
  data = dict(zip(tasks.keys(), results))
99
 
100
- # Consolidate gene sources
101
  gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
102
  genes = _flatten_unique(gene_sources)
103
 
104
- # Merge literature results
105
  papers = (data['pubmed'] or []) + (data['arxiv'] or [])
106
 
107
- # AI-driven summary
108
  summaries = " ".join(p.get('summary', '') for p in papers)
109
  if llm == 'gemini':
110
  ai_summary = await gemini_summarize(summaries)
@@ -126,6 +121,7 @@ async def orchestrate_search(
126
  'llm_used': llm_used,
127
  }
128
 
 
129
  async def answer_ai_question(
130
  question: str,
131
  context: str = "",
 
16
  from mcp.gemini import gemini_summarize, gemini_qa
17
  from mcp.openai_utils import ai_summarize, ai_qa
18
 
19
+
20
  async def _safe_call(
21
  func: Any,
22
  *args,
 
27
  Safely call an async function, returning a default on HTTP or other failures.
28
  """
29
  try:
30
+ return await func(*args, **kwargs) # type: ignore
31
  except httpx.HTTPStatusError:
32
  return default
33
  except Exception:
34
  return default
35
 
36
+
37
  async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
38
  """
39
  Await a list of asyncio.Tasks and return their results in order.
 
47
  then deduplicate preserving order.
48
  """
49
  flat: List[Any] = []
 
50
  for elem in items:
51
  if isinstance(elem, list):
52
  for x in elem:
53
+ if x not in flat:
 
54
  flat.append(x)
55
+ else:
56
+ if elem is not None and elem not in flat:
 
57
  flat.append(elem)
58
  return flat
59
 
60
+
61
  async def orchestrate_search(
62
  query: str,
63
  llm: Literal['openai', 'gemini'] = 'openai',
 
75
  Individual fetch functions that fail with an HTTP error will return an empty default,
76
  ensuring the pipeline always completes.
77
  """
 
78
  tasks = {
79
  'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
80
  'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
 
92
  'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
93
  }
94
 
 
95
  results = await _gather_tasks(list(tasks.values()))
96
  data = dict(zip(tasks.keys(), results))
97
 
 
98
  gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
99
  genes = _flatten_unique(gene_sources)
100
 
 
101
  papers = (data['pubmed'] or []) + (data['arxiv'] or [])
102
 
 
103
  summaries = " ".join(p.get('summary', '') for p in papers)
104
  if llm == 'gemini':
105
  ai_summary = await gemini_summarize(summaries)
 
121
  'llm_used': llm_used,
122
  }
123
 
124
+
125
  async def answer_ai_question(
126
  question: str,
127
  context: str = "",