Update mcp/orchestrator.py
mcp/orchestrator.py  +8 -12
@@ -16,6 +16,7 @@ from mcp.disgenet import disease_to_genes
 from mcp.gemini import gemini_summarize, gemini_qa
 from mcp.openai_utils import ai_summarize, ai_qa
 
+
 async def _safe_call(
     func: Any,
     *args,
@@ -26,12 +27,13 @@ async def _safe_call(
     Safely call an async function, returning a default on HTTP or other failures.
     """
     try:
-        return await func(*args, **kwargs)
+        return await func(*args, **kwargs)  # type: ignore
     except httpx.HTTPStatusError:
         return default
     except Exception:
         return default
 
+
 async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
     """
     Await a list of asyncio.Tasks and return their results in order.
@@ -45,19 +47,17 @@ def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
     then deduplicate preserving order.
     """
     flat: List[Any] = []
-    seen = set()
     for elem in items:
         if isinstance(elem, list):
             for x in elem:
-                if x not in seen:
-                    seen.add(x)
+                if x not in flat:
                     flat.append(x)
-
-            if elem not in seen:
-                seen.add(elem)
+        else:
+            if elem is not None and elem not in flat:
                 flat.append(elem)
     return flat
 
+
 async def orchestrate_search(
     query: str,
     llm: Literal['openai', 'gemini'] = 'openai',
@@ -75,7 +75,6 @@ async def orchestrate_search(
     Individual fetch functions that fail with an HTTP error will return an empty default,
     ensuring the pipeline always completes.
     """
-    # Launch parallel tasks with safe wrapper for potential HTTP errors
     tasks = {
         'pubmed': asyncio.create_task(fetch_pubmed(query, max_results=max_papers)),
         'arxiv': asyncio.create_task(fetch_arxiv(query, max_results=max_papers)),
@@ -93,18 +92,14 @@ async def orchestrate_search(
         'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
     }
 
-    # Await all tasks
     results = await _gather_tasks(list(tasks.values()))
     data = dict(zip(tasks.keys(), results))
 
-    # Consolidate gene sources
     gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
     genes = _flatten_unique(gene_sources)
 
-    # Merge literature results
     papers = (data['pubmed'] or []) + (data['arxiv'] or [])
 
-    # AI-driven summary
     summaries = " ".join(p.get('summary', '') for p in papers)
    if llm == 'gemini':
         ai_summary = await gemini_summarize(summaries)
@@ -126,6 +121,7 @@ async def orchestrate_search(
         'llm_used': llm_used,
     }
 
+
 async def answer_ai_question(
     question: str,
     context: str = "",
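
A quick sanity check of the revised _flatten_unique, as a standalone sketch. The function itself is taken from this commit; the sample inputs are illustrative only and merely assume the gene sources can yield plain strings, dicts, or None.

# Standalone sketch mirroring the new _flatten_unique logic in this commit.
# Deduplicating against `flat` itself (instead of a set) means unhashable
# items such as dicts no longer raise TypeError from a set lookup, and the
# `elem is not None` guard skips empty scalar results.
from typing import Any, List, Union

def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
    flat: List[Any] = []
    for elem in items:
        if isinstance(elem, list):
            for x in elem:
                if x not in flat:
                    flat.append(x)
        else:
            if elem is not None and elem not in flat:
                flat.append(elem)
    return flat

# Illustrative inputs only -- the real payload shapes from the gene sources
# may differ.
print(_flatten_unique([["TP53", "BRCA1"], ["BRCA1", {"symbol": "EGFR"}], None, "TP53"]))
# ['TP53', 'BRCA1', {'symbol': 'EGFR'}]

Membership tests against a list are O(n), which is acceptable here since only a handful of gene lists are merged per query.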
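
For completeness, a minimal sketch of how _safe_call degrades gracefully. The default and **kwargs parameters are assumed from the call sites and the function body, since the parameter list is truncated in this diff, and fail_fetch is a made-up fetcher for demonstration, not part of the module.

import asyncio
from typing import Any

import httpx

async def _safe_call(func: Any, *args, default: Any = None, **kwargs) -> Any:
    # Mirrors the helper in this commit: swallow HTTP (and any other) errors
    # and hand back the caller-supplied default instead.
    try:
        return await func(*args, **kwargs)
    except httpx.HTTPStatusError:
        return default
    except Exception:
        return default

async def fail_fetch(query: str) -> list:
    # Hypothetical fetcher that always fails with an HTTP 503.
    request = httpx.Request("GET", "https://example.org/genes")
    response = httpx.Response(503, request=request)
    raise httpx.HTTPStatusError("service unavailable", request=request, response=response)

async def main() -> None:
    genes = await _safe_call(fail_fetch, "some disease", default=[])
    print(genes)  # [] -- the orchestration keeps going instead of crashing

asyncio.run(main())

Catching broad Exception in addition to httpx.HTTPStatusError means a single flaky upstream source cannot take down orchestrate_search, matching the docstring's "ensuring the pipeline always completes".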