File size: 4,888 Bytes
2c2342d
24a46bd
2a8cf8d
 
2c2342d
 
 
 
 
 
 
 
 
 
2a8cf8d
2c2342d
 
 
afa570e
b3cf05a
24a46bd
 
 
 
 
 
 
 
 
afa570e
24a46bd
 
 
 
 
afa570e
2a8cf8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afa570e
2a8cf8d
afa570e
 
2a8cf8d
 
 
afa570e
2a8cf8d
 
 
 
 
 
 
24a46bd
2a8cf8d
 
 
 
 
 
24a46bd
 
2a8cf8d
 
 
 
 
 
 
24a46bd
 
 
 
 
 
 
 
 
2a8cf8d
 
 
 
 
 
 
 
 
 
 
 
 
 
2c2342d
2a8cf8d
 
08c0325
3d539ca
2a8cf8d
 
 
24a46bd
 
2a8cf8d
 
24a46bd
2a8cf8d
 
86771dc
 
afa570e
2a8cf8d
 
 
 
 
 
 
24a46bd
2a8cf8d
2c2342d
2a8cf8d
2c2342d
 
 
 
2a8cf8d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import asyncio
import logging
from typing import Any, Dict, List, Literal, Union

import httpx

from mcp.pubmed import fetch_pubmed
from mcp.arxiv import fetch_arxiv
from mcp.umls import extract_umls_concepts
from mcp.openfda import fetch_drug_safety
from mcp.ncbi import search_gene, get_mesh_definition
from mcp.mygene import fetch_gene_info
from mcp.ensembl import fetch_ensembl
from mcp.opentargets import fetch_ot
from mcp.clinicaltrials import search_trials
from mcp.cbio import fetch_cbio
from mcp.disgenet import disease_to_genes
from mcp.gemini import gemini_summarize, gemini_qa
from mcp.openai_utils import ai_summarize, ai_qa


async def _safe_call(
    func: Any,
    *args,
    default: Any = None,
    **kwargs,
) -> Any:
    """
    Await ``func(*args, **kwargs)`` and return its result, or ``default`` on failure.

    Args:
        func: Callable returning an awaitable (typically an ``async def`` function).
        *args: Positional arguments forwarded to ``func``.
        default: Value returned when the call raises. It is returned as-is
            (not copied) and is never forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The awaited result of ``func``, or ``default`` if any ``Exception``
        (HTTP status errors included) was raised while calling or awaiting it.

    Only ``Exception`` subclasses are intercepted; ``BaseException``-only
    errors are left to propagate. Each swallowed failure is logged as a
    warning so that a degraded result can still be diagnosed.
    """
    try:
        return await func(*args, **kwargs)  # type: ignore
    except Exception as exc:
        # A single handler suffices: the HTTP-specific case returned the same
        # default as the generic one, so it added nothing but a dependency.
        logging.getLogger(__name__).warning(
            "%s failed, falling back to default: %r",
            getattr(func, "__name__", repr(func)),
            exc,
        )
        return default


async def _gather_tasks(tasks: List[asyncio.Task]) -> List[Any]:
    """
    Wait for every task in ``tasks`` and hand back their results.

    The results come back in the same order as the tasks were supplied.
    An exception raised by any task propagates to the caller.
    """
    combined = asyncio.gather(*tasks)
    outcomes = await combined
    return outcomes


def _flatten_unique(items: List[Union[List[Any], Any]]) -> List[Any]:
    """
    Flatten one level of nesting and drop duplicates, keeping first-seen order.

    Elements of ``items`` that are lists contribute each of their members;
    any other element contributes itself unless it is ``None``. Members of
    nested lists are kept even when they are ``None``. Duplicates are found
    by equality against a list, so unhashable values are accepted.
    """

    def _candidates():
        # Yield values in the exact order they should be considered.
        for entry in items:
            if isinstance(entry, list):
                yield from entry
            elif entry is not None:
                yield entry

    unique: List[Any] = []
    for candidate in _candidates():
        if candidate in unique:
            continue
        unique.append(candidate)
    return unique


async def orchestrate_search(
    query: str,
    llm: Literal['openai', 'gemini'] = 'openai',
    max_papers: int = 7,
    max_trials: int = 10,
) -> Dict[str, Any]:
    """
    Perform a comprehensive biomedical search pipeline with fault tolerance:
      - Literature (PubMed + arXiv)
      - Entity extraction (UMLS)
      - Drug safety, gene & variant info, disease-gene mapping
      - Clinical trials, cBioPortal data
      - AI-driven summary

    Every data-source call runs concurrently and is wrapped in ``_safe_call``,
    so a source that raises contributes an empty default instead of aborting
    the whole search. The final summarization call is NOT wrapped: an
    exception raised by the selected LLM helper propagates to the caller.

    Args:
        query: Free-text query forwarded unchanged to every source.
        llm: ``'gemini'`` selects ``gemini_summarize``; any other value
            selects ``ai_summarize`` and is reported as ``'openai'``.
        max_papers: Passed as ``max_results`` to both literature fetchers.
        max_trials: Passed as ``max_studies`` to the clinical-trials search.

    Returns:
        A dict with keys ``papers``, ``genes``, ``umls``, ``gene_disease``,
        ``mesh_defs``, ``drug_safety``, ``clinical_trials``, ``variants``,
        ``ai_summary`` and ``llm_used``.
    """
    # Build one guarded coroutine per source. The literature and UMLS calls
    # are guarded too, otherwise a single failure there would raise out of
    # the gather and discard every other result.
    calls = {
        'pubmed': _safe_call(fetch_pubmed, query, default=[], max_results=max_papers),
        'arxiv': _safe_call(fetch_arxiv, query, default=[], max_results=max_papers),
        # extract_umls_concepts is invoked through a worker thread, as before.
        'umls': _safe_call(asyncio.to_thread, extract_umls_concepts, query, default=[]),
        'drug_safety': _safe_call(fetch_drug_safety, query, default=[]),
        'ncbi_gene': _safe_call(search_gene, query, default=[]),
        'mygene': _safe_call(fetch_gene_info, query, default=[]),
        'ensembl': _safe_call(fetch_ensembl, query, default=[]),
        'opentargets': _safe_call(fetch_ot, query, default=[]),
        'mesh': _safe_call(get_mesh_definition, query, default=""),
        'trials': _safe_call(search_trials, query, default=[], max_studies=max_trials),
        'cbio': _safe_call(fetch_cbio, query, default=[]),
        'disgenet': _safe_call(disease_to_genes, query, default=[]),
    }
    tasks = {name: asyncio.create_task(coro) for name, coro in calls.items()}

    results = await _gather_tasks(list(tasks.values()))
    # dicts preserve insertion order, so keys and results line up.
    data = dict(zip(tasks, results))

    gene_sources = [data['ncbi_gene'], data['mygene'], data['ensembl'], data['opentargets']]
    genes = _flatten_unique(gene_sources)

    papers = (data['pubmed'] or []) + (data['arxiv'] or [])

    # `or ''` also covers papers whose 'summary' key is present but None,
    # which would otherwise make str.join raise TypeError.
    summaries = " ".join(p.get('summary') or '' for p in papers)

    llm_used = 'gemini' if llm == 'gemini' else 'openai'
    summarize = gemini_summarize if llm_used == 'gemini' else ai_summarize
    ai_summary = await summarize(summaries)

    return {
        'papers': papers,
        'genes': genes,
        'umls': data['umls'] or [],
        'gene_disease': data['disgenet'] or [],
        'mesh_defs': [data['mesh']] if data['mesh'] else [],
        'drug_safety': data['drug_safety'] or [],
        'clinical_trials': data['trials'] or [],
        'variants': data['cbio'] or [],
        'ai_summary': ai_summary,
        'llm_used': llm_used,
    }


async def answer_ai_question(
    question: str,
    context: str = "",
    llm: Literal['openai', 'gemini'] = 'openai',
) -> Dict[str, str]:
    """
    Ask the chosen LLM backend a free-text question.

    ``llm == 'gemini'`` routes to ``gemini_qa``; every other value routes to
    ``ai_qa``. Both receive ``question`` and ``context`` positionally.

    Returns a single-key dict ``{'answer': ...}``. If the backend raises any
    ``Exception``, the answer is a message describing the failure instead of
    the exception being propagated.
    """
    try:
        # Pick the backend inside the try so that any failure, including
        # resolving the helper itself, takes the fallback path.
        responder = gemini_qa if llm == 'gemini' else ai_qa
        reply = await responder(question, context)
    except Exception as err:
        reply = f"LLM unavailable or quota exceeded: {err}"
    return {'answer': reply}