mgbam's picture
Update genesis/tools.py
e22ad8c verified
raw
history blame
6.08 kB
from __future__ import annotations
import os, json
import httpx
from typing import Any, Dict, List
class ToolBase:
    """Base class for the async tools in this module.

    Subclasses override ``call``. Tools that keep an HTTP client on
    ``self.http`` can release it with ``await tool.aclose()`` or by using the
    tool as an async context manager (``async with tool: ...``).
    """

    name: str = "tool"
    description: str = ""

    async def call(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the tool. Must be overridden by subclasses."""
        raise NotImplementedError

    async def aclose(self) -> None:
        """Close the client stored on ``self.http``, if any.

        The tools in this module create an ``httpx.AsyncClient`` in
        ``__init__``. Without an explicit close, its connection pool is never
        released. Tools without an ``http`` attribute are a no-op.
        """
        http = getattr(self, "http", None)
        if http is not None:
            await http.aclose()

    async def __aenter__(self) -> "ToolBase":
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self.aclose()
# — Ontology normalization (BioPortal)
class OntologyTool(ToolBase):
    """Normalize biomedical terms through the BioPortal search API."""

    name = "ontology_normalize"
    description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)."

    def __init__(self, timeout: float = 20.0):
        self.http = httpx.AsyncClient(timeout=timeout)
        # Optional: without a key no request is made and "bioportal" stays None.
        self.bioportal_key = os.getenv("BIOPORTAL_API_KEY")

    async def call(self, term: str) -> dict:
        """Search BioPortal for ``term``.

        Returns ``{"term": term, "bioportal": <search JSON or None>}``.
        "bioportal" is None when no API key is configured. Any failure
        (transport error, non-2xx status, undecodable JSON) is reported as a
        string under ``"bioportal_error"`` instead of being raised.
        """
        out = {"term": term, "bioportal": None}
        if not self.bioportal_key:
            return out
        try:
            r = await self.http.get(
                "https://data.bioontology.org/search",
                params={"q": term, "pagesize": 5},
                headers={"Authorization": f"apikey token={self.bioportal_key}"},
            )
            # Without this check, an error response body (e.g. an auth failure)
            # would be stored under "bioportal" as if it were search results.
            r.raise_for_status()
            out["bioportal"] = r.json()
        except Exception as e:
            out["bioportal_error"] = str(e)
        return out
# — PubMed search (NCBI E-utilities)
class PubMedTool(ToolBase):
    """Search PubMed through NCBI E-utilities (esearch, then esummary)."""

    name = "pubmed_search"
    description = "Search PubMed via NCBI; return metadata with citations."

    def __init__(self, timeout: float = 20.0, retmax: int = 20):
        self.http = httpx.AsyncClient(timeout=timeout)
        # Optional credentials; each is sent only when its variable is set.
        self.key = os.getenv("NCBI_API_KEY")
        self.email = os.getenv("NCBI_EMAIL")
        # Maximum number of PMIDs requested from esearch.
        self.retmax = retmax

    def _auth_params(self) -> Dict[str, str]:
        """Return only the api_key/email query parameters that are set.

        httpx encodes a ``None`` parameter value as an empty string. Passing
        these unconditionally would therefore send ``api_key=&email=`` when
        the environment variables are unset.
        """
        pairs = (("api_key", self.key), ("email", self.email))
        return {k: v for k, v in pairs if v}

    @staticmethod
    def _summarize(pmid: str, rec: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce one esummary record to pmid/title/journal/year/authors."""
        return {
            "pmid": pmid,
            "title": rec.get("title"),
            "journal": rec.get("fulljournalname"),
            # Assumes pubdate begins with a 4-digit year; keeps its first 4 chars.
            "year": (rec.get("pubdate") or "")[:4],
            "authors": [a.get("name") for a in rec.get("authors", [])],
        }

    async def call(self, query: str) -> dict:
        """Run esearch for ``query``, then esummary for the returned PMIDs.

        Returns ``{"query": query, "results": [...]}`` on success, with results
        in esearch order. On any failure, including non-2xx HTTP responses,
        returns ``{"query": query, "error": str}``.
        """
        base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"
        auth = self._auth_params()
        try:
            es = await self.http.get(
                base + "esearch.fcgi",
                params={
                    "db": "pubmed",
                    "term": query,
                    "retmode": "json",
                    "retmax": self.retmax,
                    **auth,
                },
            )
            # A rate-limit or server error must not be reported as "no results".
            es.raise_for_status()
            esr = es.json().get("esearchresult", {})
            ids = esr.get("idlist", [])
            if not ids:
                if esr.get("ERROR"):
                    # NOTE(review): esearch appears to report query problems under
                    # "ERROR" with no idlist; surface it instead of empty results.
                    return {"query": query, "error": str(esr["ERROR"])}
                return {"query": query, "results": []}
            su = await self.http.get(
                base + "esummary.fcgi",
                params={"db": "pubmed", "id": ",".join(ids), "retmode": "json", **auth},
            )
            su.raise_for_status()
            recs = su.json().get("result", {})
            items = [self._summarize(pmid, recs.get(pmid, {})) for pmid in ids]
            return {"query": query, "results": items}
        except Exception as e:
            return {"query": query, "error": str(e)}
# — RCSB structure metadata
class StructureTool(ToolBase):
    """Fetch entry-level metadata for a PDB ID from the RCSB Data API."""

    name = "structure_info"
    description = "Query RCSB structure metadata (no lab steps)."

    def __init__(self, timeout: float = 20.0):
        self.http = httpx.AsyncClient(timeout=timeout)

    async def call(self, pdb_id: str) -> dict:
        """Return ``{"pdb_id": pdb_id, "rcsb_core": <entry JSON>}``.

        On any failure (empty ID, transport error, non-2xx status, bad JSON),
        the ``"rcsb_core"`` key is replaced by ``"error": str``. The original
        ``pdb_id`` value is always echoed back unchanged.
        """
        from urllib.parse import quote

        out = {"pdb_id": pdb_id}
        try:
            # The ID is interpolated into the URL path. Trim stray whitespace and
            # percent-encode it (including "/") so it cannot add path segments.
            entry = quote(str(pdb_id).strip(), safe="")
            if not entry:
                raise ValueError("empty PDB ID")
            r = await self.http.get(f"https://data.rcsb.org/rest/v1/core/entry/{entry}")
            r.raise_for_status()
            out["rcsb_core"] = r.json()
        except Exception as e:
            out["error"] = str(e)
        return out
# — Crossref DOIs
class CrossrefTool(ToolBase):
    """Search Crossref's /works endpoint and return title/DOI/year/authors."""

    name = "crossref_search"
    description = "Crossref search for DOIs; titles, years, authors."

    def __init__(self, timeout: float = 20.0, rows: int = 10):
        self.http = httpx.AsyncClient(timeout=timeout)
        # Number of works requested per search.
        self.rows = rows

    @staticmethod
    def _author_name(author: Dict[str, Any]) -> str:
        """Format one author entry.

        Personal authors use "given family". Entries with neither part (e.g.
        organizational authors) fall back to their ``name`` field. The
        previous formatting produced an empty string for these.
        """
        personal = " ".join(p for p in (author.get("given"), author.get("family")) if p)
        return personal or (author.get("name") or "")

    @staticmethod
    def _year(item: Dict[str, Any]) -> Any:
        """Return the first element of ``issued.date-parts``, or None.

        Empty ``date-parts`` (``[]`` or ``[[]]``) yields None instead of an
        IndexError that would discard the whole result set.
        """
        parts = (item.get("issued") or {}).get("date-parts") or [[None]]
        first = parts[0] or [None]
        return first[0]

    async def call(self, query: str) -> dict:
        """Search Crossref for ``query``.

        Returns ``{"query": query, "results": [...]}`` on success. On any
        failure, including non-2xx HTTP responses, returns
        ``{"query": query, "error": str}``.
        """
        try:
            r = await self.http.get(
                "https://api.crossref.org/works",
                params={"query": query, "rows": self.rows},
            )
            r.raise_for_status()
            items = r.json().get("message", {}).get("items", [])
            papers = [
                {
                    "title": (it.get("title") or [None])[0],
                    "doi": it.get("DOI"),
                    "year": self._year(it),
                    "authors": [self._author_name(a) for a in it.get("author", [])],
                }
                for it in items
            ]
            return {"query": query, "results": papers}
        except Exception as e:
            return {"query": query, "error": str(e)}
# — HF Inference API Reranker (optional)
class HFRerankTool(ToolBase):
    """Rerank documents against a query with a Hugging Face Inference API model."""

    name = "hf_rerank"
    description = "Rerank documents using a Hugging Face reranker model (API)."

    def __init__(self, model_id: str, timeout: float = 30.0):
        self.model = model_id
        self.hf_token = os.getenv("HF_TOKEN")
        self.http = httpx.AsyncClient(timeout=timeout)

    @staticmethod
    def _parse_scores(data: Any, n_docs: int) -> List[float]:
        """Extract one score per document, in document order.

        Recognized response shapes:
          * ``{"scores": [...]}``, scores in document order;
          * a list of ``{"score": ...}`` dicts. When every entry carries an
            ``"index"``, scores are placed by that index, because rerank-style
            output may be ordered by score rather than by document. Otherwise
            they are taken positionally;
          * a plain list of numbers, in document order.
        Any other shape yields equal scores, so the input order is kept.

        Raises:
            ValueError: if the scores do not line up with ``n_docs`` documents.
        """
        if isinstance(data, dict) and "scores" in data:
            scores = [float(s) for s in data["scores"]]
        elif isinstance(data, list) and data and isinstance(data[0], dict) and "score" in data[0]:
            if all(isinstance(x, dict) and "index" in x for x in data):
                by_index = {int(x["index"]): float(x.get("score", 0.0)) for x in data}
                if sorted(by_index) != list(range(n_docs)):
                    raise ValueError("reranker indices do not match the documents")
                return [by_index[i] for i in range(n_docs)]
            scores = [float(x.get("score", 0.0)) for x in data]
        elif isinstance(data, list) and data and all(
            isinstance(x, (int, float)) and not isinstance(x, bool) for x in data
        ):
            scores = [float(x) for x in data]
        else:
            # Unknown schema: keep the original document order.
            return [1.0] * n_docs
        if len(scores) != n_docs:
            raise ValueError(f"expected {n_docs} scores, got {len(scores)}")
        return scores

    async def call(self, query: str, documents: List[str]) -> dict:
        """Score ``documents`` against ``query`` and rank them.

        Returns ``{"order": [doc indices, best first], "scores": [per-document
        scores in input order], "raw": <response JSON>}``. On failure it
        returns ``{"error": str}``, plus ``"raw"`` when the API itself
        reported the error.
        """
        if not self.hf_token:
            return {"error": "HF_TOKEN not set"}
        if not documents:
            # Nothing to rank; skip the network round-trip.
            return {"order": [], "scores": [], "raw": None}
        try:
            # Generic payload; different models may expect different schemas — keep robust.
            payload = {"inputs": {"query": query, "texts": documents}}
            r = await self.http.post(
                f"https://api-inference.huggingface.co/models/{self.model}",
                headers={"Authorization": f"Bearer {self.hf_token}"},
                json=payload,
            )
            data = r.json()
            if isinstance(data, dict) and "error" in data:
                # The API reports failures as {"error": ...}. Do not fall
                # through to equal scores and present them as a ranking.
                return {"error": str(data["error"]), "raw": data}
            scores = self._parse_scores(data, len(documents))
            # Stable sort: equal scores keep their input order.
            order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            return {"order": order, "scores": scores, "raw": data}
        except Exception as e:
            return {"error": str(e)}