MCP_Res / mcp /orchestrator.py
mgbam's picture
Update mcp/orchestrator.py
afa570e verified
raw
history blame
4.89 kB
import asyncio
import httpx
from typing import Any, Dict, List, Literal, Union
from mcp.pubmed import fetch_pubmed
from mcp.arxiv import fetch_arxiv
from mcp.umls import extract_umls_concepts
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.clinicaltrials import search_trials
from mcp.cbio import fetch_cbio
from mcp.disgenet import disease_to_genes
from mcp.gemini import gemini_summarize, gemini_qa
from mcp.openai_utils import ai_summarize, ai_qa
async def _safe_call(
func: Any,
*args,
default: Any = None,
**kwargs,
) -> Any:
"""
Safely call an async function, returning a default on HTTP or other failures.
"""
try:
return await func(*args, **kwargs) # type: ignore
except httpx.HTTPStatusError:
return default
except Exception:
return default
async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
"""
Await a list of asyncio.Tasks and return their results in order.
"""
return await asyncio.gather(*tasks)
def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
"""
Flatten a list of items where elements may be lists or single values,
then deduplicate preserving order.
"""
flat: List[Any] = []
for elem in items:
if isinstance(elem, list):
for x in elem:
if x not in flat:
flat.append(x)
else:
if elem is not None and elem not in flat:
flat.append(elem)
return flat
async def orchestrate_search(
    query: str,
    llm: Literal['openai', 'gemini'] = 'openai',
    max_papers: int = 7,
    max_trials: int = 10,
) -> Dict[str, Any]:
    """
    Perform a comprehensive biomedical search pipeline with fault tolerance:
    - Literature (PubMed + arXiv)
    - Entity extraction (UMLS)
    - Drug safety, gene & variant info, disease-gene mapping
    - Clinical trials, cBioPortal data
    - AI-driven summary

    Every fetch — including PubMed, arXiv, and UMLS, which were previously
    unguarded and could crash the whole gather — is wrapped in ``_safe_call``
    so a failing source degrades to an empty default and the pipeline always
    completes.

    Parameters
    ----------
    query : str
        Free-text biomedical query.
    llm : 'openai' | 'gemini'
        Which LLM backend produces the summary (default 'openai').
    max_papers : int
        Per-source cap on literature results.
    max_trials : int
        Cap on clinical-trial results.

    Returns
    -------
    Dict[str, Any]
        Aggregated results keyed by category, plus 'ai_summary' and
        'llm_used'.
    """
    tasks = {
        # Literature: now fault-tolerant (was unguarded before).
        'pubmed': asyncio.create_task(
            _safe_call(fetch_pubmed, query, default=[], max_results=max_papers)
        ),
        'arxiv': asyncio.create_task(
            _safe_call(fetch_arxiv, query, default=[], max_results=max_papers)
        ),
        # UMLS extraction is synchronous; run it in a thread, still guarded.
        'umls': asyncio.create_task(
            _safe_call(asyncio.to_thread, extract_umls_concepts, query, default=[])
        ),
        'drug_safety': asyncio.create_task(_safe_call(fetch_drug_safety, query, default=[])),
        'ncbi_gene': asyncio.create_task(_safe_call(search_gene, query, default=[])),
        'mygene': asyncio.create_task(_safe_call(fetch_gene_info, query, default=[])),
        'ensembl': asyncio.create_task(_safe_call(fetch_ensembl, query, default=[])),
        'opentargets': asyncio.create_task(_safe_call(fetch_ot, query, default=[])),
        'mesh': asyncio.create_task(_safe_call(get_mesh_definition, query, default="")),
        'trials': asyncio.create_task(_safe_call(search_trials, query, default=[], max_studies=max_trials)),
        'cbio': asyncio.create_task(_safe_call(fetch_cbio, query, default=[])),
        'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
    }
    results = await _gather_tasks(list(tasks.values()))
    data = dict(zip(tasks.keys(), results))

    # Merge gene hits from all four gene sources, deduplicated in order.
    gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
    genes = _flatten_unique(gene_sources)

    # Literature from both sources; each entry presumably carries a
    # 'summary' field (missing keys degrade to "").
    papers = (data['pubmed'] or []) + (data['arxiv'] or [])
    summaries = " ".join(p.get('summary', '') for p in papers)

    if llm == 'gemini':
        ai_summary = await gemini_summarize(summaries)
        llm_used = 'gemini'
    else:
        # Any value other than 'gemini' falls back to OpenAI.
        ai_summary = await ai_summarize(summaries)
        llm_used = 'openai'

    return {
        'papers': papers,
        'genes': genes,
        'umls': data['umls'] or [],
        'gene_disease': data['disgenet'] or [],
        'mesh_defs': [data['mesh']] if data['mesh'] else [],
        'drug_safety': data['drug_safety'] or [],
        'clinical_trials': data['trials'] or [],
        'variants': data['cbio'] or [],
        'ai_summary': ai_summary,
        'llm_used': llm_used,
    }
async def answer_ai_question(
    question: str,
    context: str = "",
    llm: Literal['openai', 'gemini'] = 'openai',
) -> Dict[str, str]:
    """
    Route a free-text question to the selected LLM backend.

    Any failure (backend unavailable, quota exhausted, network error) is
    converted to a human-readable message instead of propagating, so this
    coroutine never raises. Returns ``{'answer': text}``.
    """
    try:
        # Resolve and call the backend inside the try so that even a missing
        # backend name degrades to the fallback message.
        backend = gemini_qa if llm == 'gemini' else ai_qa
        reply = await backend(question, context)
    except Exception as exc:
        reply = f"LLM unavailable or quota exceeded: {exc}"
    return {'answer': reply}