Spaces:
Sleeping
Sleeping
Update genesis/tools.py
Browse files- genesis/tools.py +59 -19
genesis/tools.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1 |
-
|
2 |
from __future__ import annotations
|
3 |
-
import os, json
|
4 |
import httpx
|
5 |
-
from typing import Any, Dict,
|
6 |
|
7 |
class ToolBase:
|
8 |
name: str = "tool"
|
9 |
description: str = ""
|
10 |
-
|
11 |
async def call(self, *args, **kwargs) -> Dict[str, Any]:
|
12 |
raise NotImplementedError
|
13 |
|
|
|
14 |
class OntologyTool(ToolBase):
|
15 |
name = "ontology_normalize"
|
16 |
description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)."
|
@@ -26,13 +25,14 @@ class OntologyTool(ToolBase):
|
|
26 |
r = await self.http.get(
|
27 |
"https://data.bioontology.org/search",
|
28 |
params={"q": term, "pagesize": 5},
|
29 |
-
headers={"Authorization": f"apikey token={self.bioportal_key}"}
|
30 |
)
|
31 |
out["bioportal"] = r.json()
|
32 |
except Exception as e:
|
33 |
out["bioportal_error"] = str(e)
|
34 |
return out
|
35 |
|
|
|
36 |
class PubMedTool(ToolBase):
|
37 |
name = "pubmed_search"
|
38 |
description = "Search PubMed via NCBI; return metadata with citations."
|
@@ -47,29 +47,31 @@ class PubMedTool(ToolBase):
|
|
47 |
try:
|
48 |
es = await self.http.get(
|
49 |
base + "esearch.fcgi",
|
50 |
-
params={"db":"pubmed","term":query,"retmode":"json","retmax":20,"api_key":self.key,"email":self.email}
|
51 |
)
|
52 |
-
ids = es.json().get("esearchresult",{}).get("idlist",[])
|
53 |
-
if not ids:
|
|
|
54 |
su = await self.http.get(
|
55 |
base + "esummary.fcgi",
|
56 |
-
params={"db":"pubmed","id":",".join(ids),"retmode":"json","api_key":self.key,"email":self.email}
|
57 |
)
|
58 |
-
recs = su.json().get("result",{})
|
59 |
items = []
|
60 |
for pmid in ids:
|
61 |
-
r = recs.get(pmid,{
|
62 |
items.append({
|
63 |
"pmid": pmid,
|
64 |
"title": r.get("title"),
|
65 |
"journal": r.get("fulljournalname"),
|
66 |
"year": (r.get("pubdate") or "")[:4],
|
67 |
-
"authors": [a.get("name") for a in r.get("authors",[])],
|
68 |
})
|
69 |
-
return {"query":query,"results":items}
|
70 |
except Exception as e:
|
71 |
-
return {"query":query,"error":str(e)}
|
72 |
|
|
|
73 |
class StructureTool(ToolBase):
|
74 |
name = "structure_info"
|
75 |
description = "Query RCSB structure metadata (no lab steps)."
|
@@ -87,6 +89,7 @@ class StructureTool(ToolBase):
|
|
87 |
out["error"] = str(e)
|
88 |
return out
|
89 |
|
|
|
90 |
class CrossrefTool(ToolBase):
|
91 |
name = "crossref_search"
|
92 |
description = "Crossref search for DOIs; titles, years, authors."
|
@@ -96,16 +99,53 @@ class CrossrefTool(ToolBase):
|
|
96 |
|
97 |
async def call(self, query: str) -> dict:
|
98 |
try:
|
99 |
-
r = await self.http.get("https://api.crossref.org/works", params={"query":query,"rows":10})
|
100 |
-
items = r.json().get("message",{}).get("items",[])
|
101 |
papers = []
|
102 |
for it in items:
|
103 |
papers.append({
|
104 |
"title": (it.get("title") or [None])[0],
|
105 |
"doi": it.get("DOI"),
|
106 |
"year": (it.get("issued") or {}).get("date-parts", [[None]])[0][0],
|
107 |
-
"authors": [f"{a.get('given','')} {a.get('family','')}".strip() for a in it.get("author",[])],
|
108 |
})
|
109 |
-
return {"query":query,"results":papers}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
except Exception as e:
|
111 |
-
return {"
|
|
|
|
|
1 |
from __future__ import annotations
|
2 |
+
import os, json
|
3 |
import httpx
|
4 |
+
from typing import Any, Dict, List
|
5 |
|
6 |
class ToolBase:
    """Common base for the async tools in this module.

    Subclasses replace ``name`` and ``description`` and override ``call``.
    """

    # Short identifier for the tool; subclasses assign their own value.
    name: str = "tool"
    # Human-readable summary of what the tool does.
    description: str = ""

    async def call(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the tool; the base implementation always raises.

        Raises:
            NotImplementedError: always, until a subclass overrides ``call``.
        """
        raise NotImplementedError
|
11 |
|
12 |
+
# — Ontology normalization (BioPortal)
|
13 |
class OntologyTool(ToolBase):
|
14 |
name = "ontology_normalize"
|
15 |
description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)."
|
|
|
25 |
r = await self.http.get(
|
26 |
"https://data.bioontology.org/search",
|
27 |
params={"q": term, "pagesize": 5},
|
28 |
+
headers={"Authorization": f"apikey token={self.bioportal_key}"},
|
29 |
)
|
30 |
out["bioportal"] = r.json()
|
31 |
except Exception as e:
|
32 |
out["bioportal_error"] = str(e)
|
33 |
return out
|
34 |
|
35 |
+
# — PubMed search (NCBI E-utilities)
|
36 |
class PubMedTool(ToolBase):
|
37 |
name = "pubmed_search"
|
38 |
description = "Search PubMed via NCBI; return metadata with citations."
|
|
|
47 |
try:
|
48 |
es = await self.http.get(
|
49 |
base + "esearch.fcgi",
|
50 |
+
params={"db":"pubmed","term":query,"retmode":"json","retmax":20,"api_key":self.key,"email":self.email},
|
51 |
)
|
52 |
+
ids = es.json().get("esearchresult", {}).get("idlist", [])
|
53 |
+
if not ids:
|
54 |
+
return {"query": query, "results": []}
|
55 |
su = await self.http.get(
|
56 |
base + "esummary.fcgi",
|
57 |
+
params={"db":"pubmed","id":",".join(ids),"retmode":"json","api_key":self.key,"email":self.email},
|
58 |
)
|
59 |
+
recs = su.json().get("result", {})
|
60 |
items = []
|
61 |
for pmid in ids:
|
62 |
+
r = recs.get(pmid, {})
|
63 |
items.append({
|
64 |
"pmid": pmid,
|
65 |
"title": r.get("title"),
|
66 |
"journal": r.get("fulljournalname"),
|
67 |
"year": (r.get("pubdate") or "")[:4],
|
68 |
+
"authors": [a.get("name") for a in r.get("authors", [])],
|
69 |
})
|
70 |
+
return {"query": query, "results": items}
|
71 |
except Exception as e:
|
72 |
+
return {"query": query, "error": str(e)}
|
73 |
|
74 |
+
# — RCSB structure metadata
|
75 |
class StructureTool(ToolBase):
|
76 |
name = "structure_info"
|
77 |
description = "Query RCSB structure metadata (no lab steps)."
|
|
|
89 |
out["error"] = str(e)
|
90 |
return out
|
91 |
|
92 |
+
# — Crossref DOIs
|
93 |
class CrossrefTool(ToolBase):
|
94 |
name = "crossref_search"
|
95 |
description = "Crossref search for DOIs; titles, years, authors."
|
|
|
99 |
|
100 |
async def call(self, query: str) -> dict:
    """Search Crossref works for ``query`` and summarise the hits.

    Args:
        query: Free-text search string, sent as the ``query`` parameter.

    Returns:
        ``{"query": query, "results": [...]}`` where each result carries the
        keys ``title``, ``doi``, ``year`` and ``authors``; or
        ``{"query": query, "error": message}`` when the request or the
        parsing of the response fails.
    """
    try:
        r = await self.http.get(
            "https://api.crossref.org/works",
            params={"query": query, "rows": 10},
        )
        # Report HTTP failures as an error rather than as an empty result list.
        r.raise_for_status()
        items = r.json().get("message", {}).get("items", [])
        papers = []
        for it in items:
            # "date-parts" may be missing, empty, or hold an empty inner list.
            # Each of those yields year=None instead of an IndexError that
            # would throw away every other result.
            date_parts = (it.get("issued") or {}).get("date-parts") or [[None]]
            first_date = date_parts[0] or [None]
            papers.append({
                "title": (it.get("title") or [None])[0],
                "doi": it.get("DOI"),
                "year": first_date[0],
                "authors": [
                    f"{a.get('given','')} {a.get('family','')}".strip()
                    for a in (it.get("author") or [])
                ],
            })
        return {"query": query, "results": papers}
    except Exception as e:
        # Best-effort tool: hand the failure back to the caller, do not raise.
        return {"query": query, "error": str(e)}
|
115 |
+
|
116 |
+
# — HF Inference API Reranker (optional)
|
117 |
+
class HFRerankTool(ToolBase):
    """Rerank documents through the Hugging Face Inference API.

    The access token is read from the ``HF_TOKEN`` environment variable when
    the tool is constructed; without it ``call`` returns an error dict.
    """

    name = "hf_rerank"
    description = "Rerank documents using a Hugging Face reranker model (API)."

    def __init__(self, model_id: str):
        """Store the model id, read ``HF_TOKEN`` and create the HTTP client.

        Args:
            model_id: Model identifier appended to the inference API URL.
        """
        self.model = model_id
        self.hf_token = os.getenv("HF_TOKEN")
        self.http = httpx.AsyncClient(timeout=30.0)

    @staticmethod
    def _extract_scores(data: Any, count: int) -> List[float]:
        """Return the scores found in ``data``, or ``count`` equal scores."""
        scores = None
        if isinstance(data, dict) and "scores" in data:
            scores = data["scores"]
        elif (
            isinstance(data, list)
            and data
            and all(isinstance(x, dict) for x in data)
            and "score" in data[0]
        ):
            scores = [x.get("score", 0.0) for x in data]

        # Only accept a payload that can actually rank every document: it must
        # be a sequence of numbers with at least one entry per document.
        # Otherwise indexing or sorting would raise and mask the fallback.
        usable = (
            isinstance(scores, (list, tuple))
            and len(scores) >= count
            and all(
                isinstance(s, (int, float)) and not isinstance(s, bool)
                for s in scores
            )
        )
        if usable:
            return scores
        # Fallback: equal scores, which keeps the original document order.
        return [1.0 for _ in range(count)]

    async def call(self, query: str, documents: List[str]) -> dict:
        """Score ``documents`` against ``query`` and return a ranking.

        Args:
            query: The text the documents are ranked against.
            documents: Candidate texts to rank.

        Returns:
            ``{"order": [...], "scores": [...], "raw": response_json}`` where
            ``order`` lists document indices by descending score; or
            ``{"error": message}`` when ``HF_TOKEN`` is unset or the request
            fails. Unrecognised or unusable score payloads fall back to
            equal scores instead of an error.
        """
        if not self.hf_token:
            return {"error": "HF_TOKEN not set"}
        try:
            # Generic payload; different models may expect different schemas — keep robust.
            payload = {"inputs": {"query": query, "texts": documents}}
            r = await self.http.post(
                f"https://api-inference.huggingface.co/models/{self.model}",
                headers={"Authorization": f"Bearer {self.hf_token}"},
                json=payload,
            )
            data = r.json()
            scores = self._extract_scores(data, len(documents))
            # Sort indices by score, highest first; the sort is stable, so
            # ties keep the original document order.
            order = sorted(
                range(len(documents)), key=lambda i: scores[i], reverse=True
            )
            return {"order": order, "scores": scores, "raw": data}
        except Exception as e:
            # Best-effort tool: hand the failure back to the caller, do not raise.
            return {"error": str(e)}
|