from __future__ import annotations import os, json import httpx from typing import Any, Dict, List class ToolBase: name: str = "tool" description: str = "" async def call(self, *args, **kwargs) -> Dict[str, Any]: raise NotImplementedError # — Ontology normalization (BioPortal) class OntologyTool(ToolBase): name = "ontology_normalize" description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)." def __init__(self, timeout: float = 20.0): self.http = httpx.AsyncClient(timeout=timeout) self.bioportal_key = os.getenv("BIOPORTAL_API_KEY") async def call(self, term: str) -> dict: out = {"term": term, "bioportal": None} try: if self.bioportal_key: r = await self.http.get( "https://data.bioontology.org/search", params={"q": term, "pagesize": 5}, headers={"Authorization": f"apikey token={self.bioportal_key}"}, ) out["bioportal"] = r.json() except Exception as e: out["bioportal_error"] = str(e) return out # — PubMed search (NCBI E-utilities) class PubMedTool(ToolBase): name = "pubmed_search" description = "Search PubMed via NCBI; return metadata with citations." def __init__(self, timeout: float = 20.0): self.http = httpx.AsyncClient(timeout=timeout) self.key = os.getenv("NCBI_API_KEY") self.email = os.getenv("NCBI_EMAIL") async def call(self, query: str) -> dict: base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/" try: es = await self.http.get( base + "esearch.fcgi", params={"db":"pubmed","term":query,"retmode":"json","retmax":20,"api_key":self.key,"email":self.email}, ) ids = es.json().get("esearchresult", {}).get("idlist", []) if not ids: return {"query": query, "results": []} su = await self.http.get( base + "esummary.fcgi", params={"db":"pubmed","id":",".join(ids),"retmode":"json","api_key":self.key,"email":self.email}, ) recs = su.json().get("result", {}) items = [] for pmid in ids: r = recs.get(pmid, {}) items.append({ "pmid": pmid, "title": r.get("title"), "journal": r.get("fulljournalname"), "year": (r.get("pubdate") or "")[:4], "authors": [a.get("name") for a in r.get("authors", [])], }) return {"query": query, "results": items} except Exception as e: return {"query": query, "error": str(e)} # — RCSB structure metadata class StructureTool(ToolBase): name = "structure_info" description = "Query RCSB structure metadata (no lab steps)." def __init__(self, timeout: float = 20.0): self.http = httpx.AsyncClient(timeout=timeout) async def call(self, pdb_id: str) -> dict: out = {"pdb_id": pdb_id} try: r = await self.http.get(f"https://data.rcsb.org/rest/v1/core/entry/{pdb_id}") r.raise_for_status() out["rcsb_core"] = r.json() except Exception as e: out["error"] = str(e) return out # — Crossref DOIs class CrossrefTool(ToolBase): name = "crossref_search" description = "Crossref search for DOIs; titles, years, authors." def __init__(self, timeout: float = 20.0): self.http = httpx.AsyncClient(timeout=timeout) async def call(self, query: str) -> dict: try: r = await self.http.get("https://api.crossref.org/works", params={"query": query, "rows": 10}) items = r.json().get("message", {}).get("items", []) papers = [] for it in items: papers.append({ "title": (it.get("title") or [None])[0], "doi": it.get("DOI"), "year": (it.get("issued") or {}).get("date-parts", [[None]])[0][0], "authors": [f"{a.get('given','')} {a.get('family','')}".strip() for a in it.get("author", [])], }) return {"query": query, "results": papers} except Exception as e: return {"query": query, "error": str(e)} # — HF Inference API Reranker (optional) class HFRerankTool(ToolBase): name = "hf_rerank" description = "Rerank documents using a Hugging Face reranker model (API)." def __init__(self, model_id: str): self.model = model_id self.hf_token = os.getenv("HF_TOKEN") self.http = httpx.AsyncClient(timeout=30.0) async def call(self, query: str, documents: List[str]) -> dict: if not self.hf_token: return {"error": "HF_TOKEN not set"} try: # Generic payload; different models may expect different schemas — keep robust. payload = {"inputs": {"query": query, "texts": documents}} r = await self.http.post( f"https://api-inference.huggingface.co/models/{self.model}", headers={"Authorization": f"Bearer {self.hf_token}"}, json=payload, ) data = r.json() # Try to interpret scores scores = [] if isinstance(data, dict) and "scores" in data: scores = data["scores"] elif isinstance(data, list) and data and isinstance(data[0], dict) and "score" in data[0]: scores = [x.get("score", 0.0) for x in data] else: # Fallback: equal scores scores = [1.0 for _ in documents] # Sort indices by score desc order = sorted(range(len(documents)), key=lambda i: scores[i], reverse=True) return {"order": order, "scores": scores, "raw": data} except Exception as e: return {"error": str(e)}