File size: 4,888 Bytes
2c2342d 24a46bd 2a8cf8d 2c2342d 2a8cf8d 2c2342d afa570e b3cf05a 24a46bd afa570e 24a46bd afa570e 2a8cf8d afa570e 2a8cf8d afa570e 2a8cf8d afa570e 2a8cf8d 24a46bd 2a8cf8d 24a46bd 2a8cf8d 24a46bd 2a8cf8d 2c2342d 2a8cf8d 08c0325 3d539ca 2a8cf8d 24a46bd 2a8cf8d 24a46bd 2a8cf8d 86771dc afa570e 2a8cf8d 24a46bd 2a8cf8d 2c2342d 2a8cf8d 2c2342d 2a8cf8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import asyncio
import httpx
from typing import Any, Dict, List, Literal, Union
from mcp.pubmed import fetch_pubmed
from mcp.arxiv import fetch_arxiv
from mcp.umls import extract_umls_concepts
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.clinicaltrials import search_trials
from mcp.cbio import fetch_cbio
from mcp.disgenet import disease_to_genes
from mcp.gemini import gemini_summarize, gemini_qa
from mcp.openai_utils import ai_summarize, ai_qa
async def _safe_call(
func: Any,
*args,
default: Any = None,
**kwargs,
) -> Any:
"""
Safely call an async function, returning a default on HTTP or other failures.
"""
try:
return await func(*args, **kwargs) # type: ignore
except httpx.HTTPStatusError:
return default
except Exception:
return default
async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
"""
Await a list of asyncio.Tasks and return their results in order.
"""
return await asyncio.gather(*tasks)
def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
"""
Flatten a list of items where elements may be lists or single values,
then deduplicate preserving order.
"""
flat: List[Any] = []
for elem in items:
if isinstance(elem, list):
for x in elem:
if x not in flat:
flat.append(x)
else:
if elem is not None and elem not in flat:
flat.append(elem)
return flat
async def orchestrate_search(
    query: str,
    llm: Literal['openai', 'gemini'] = 'openai',
    max_papers: int = 7,
    max_trials: int = 10,
) -> Dict[str, Any]:
    """
    Perform a comprehensive biomedical search pipeline with fault tolerance:
    - Literature (PubMed + arXiv)
    - Entity extraction (UMLS)
    - Drug safety, gene & variant info, disease-gene mapping
    - Clinical trials, cBioPortal data
    - AI-driven summary

    Every source lookup runs concurrently and is wrapped in ``_safe_call``,
    so a source that raises contributes its empty default instead of
    aborting the whole search. A failure of the summarising LLM is likewise
    contained: ``ai_summary`` then carries an explanatory message.

    Args:
        query: Free-text search query passed to every source.
        llm: Which backend writes the summary; any value other than
            ``'gemini'`` selects OpenAI.
        max_papers: Passed as ``max_results`` to each literature source.
        max_trials: Passed as ``max_studies`` to the clinical-trials search.

    Returns:
        A dict with the keys ``papers``, ``genes``, ``umls``,
        ``gene_disease``, ``mesh_defs``, ``drug_safety``,
        ``clinical_trials``, ``variants``, ``ai_summary`` and ``llm_used``.
    """
    # asyncio.gather propagates the first exception it sees, so any lookup
    # left unguarded here would discard the results of all the others.
    tasks = {
        'pubmed': asyncio.create_task(
            _safe_call(fetch_pubmed, query, default=[], max_results=max_papers)
        ),
        'arxiv': asyncio.create_task(
            _safe_call(fetch_arxiv, query, default=[], max_results=max_papers)
        ),
        # extract_umls_concepts is run in a worker thread via to_thread.
        'umls': asyncio.create_task(
            _safe_call(asyncio.to_thread, extract_umls_concepts, query, default=[])
        ),
        'drug_safety': asyncio.create_task(_safe_call(fetch_drug_safety, query, default=[])),
        'ncbi_gene': asyncio.create_task(_safe_call(search_gene, query, default=[])),
        'mygene': asyncio.create_task(_safe_call(fetch_gene_info, query, default=[])),
        'ensembl': asyncio.create_task(_safe_call(fetch_ensembl, query, default=[])),
        'opentargets': asyncio.create_task(_safe_call(fetch_ot, query, default=[])),
        'mesh': asyncio.create_task(_safe_call(get_mesh_definition, query, default="")),
        'trials': asyncio.create_task(
            _safe_call(search_trials, query, default=[], max_studies=max_trials)
        ),
        'cbio': asyncio.create_task(_safe_call(fetch_cbio, query, default=[])),
        'disgenet': asyncio.create_task(_safe_call(disease_to_genes, query, default=[])),
    }
    # Dicts preserve insertion order, so results line up with the task keys.
    results = await _gather_tasks(list(tasks.values()))
    data = dict(zip(tasks.keys(), results))

    gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
    genes = _flatten_unique(gene_sources)

    papers = (data['pubmed'] or []) + (data['arxiv'] or [])
    # `or ''` also covers a 'summary' key that is present but None, which
    # would otherwise make str.join raise TypeError.
    summaries = " ".join((p.get('summary') or '') for p in papers)

    llm_used = 'gemini' if llm == 'gemini' else 'openai'
    summarize = gemini_summarize if llm_used == 'gemini' else ai_summarize
    try:
        ai_summary = await summarize(summaries)
    except Exception as e:
        # Same fallback wording as answer_ai_question; the gathered data is
        # still returned even when the LLM is down or out of quota.
        ai_summary = f"LLM unavailable or quota exceeded: {e}"

    return {
        'papers': papers,
        'genes': genes,
        'umls': data['umls'] or [],
        'gene_disease': data['disgenet'] or [],
        'mesh_defs': [data['mesh']] if data['mesh'] else [],
        'drug_safety': data['drug_safety'] or [],
        'clinical_trials': data['trials'] or [],
        'variants': data['cbio'] or [],
        'ai_summary': ai_summary,
        'llm_used': llm_used,
    }
async def answer_ai_question(
    question: str,
    context: str = "",
    llm: Literal['openai', 'gemini'] = 'openai',
) -> Dict[str, str]:
    """
    Ask the chosen LLM a free-text question and wrap the reply in a dict.

    ``llm='gemini'`` routes to the Gemini QA helper; any other value routes
    to the OpenAI one. If the call raises, the answer is replaced by a
    message describing the failure rather than propagating the exception.

    Returns:
        ``{'answer': text}``.
    """
    try:
        # Backend selection stays inside the try so every failure on this
        # path ends in the same fallback message.
        backend = gemini_qa if llm == 'gemini' else ai_qa
        reply = await backend(question, context)
    except Exception as e:
        reply = f"LLM unavailable or quota exceeded: {e}"
    return {'answer': reply}
|