mgbam commited on
Commit
e22ad8c
Β·
verified Β·
1 Parent(s): 9d9338d

Update genesis/tools.py

Browse files
Files changed (1) hide show
  1. genesis/tools.py +59 -19
genesis/tools.py CHANGED
@@ -1,16 +1,15 @@
1
-
2
  from __future__ import annotations
3
- import os, json, re
4
  import httpx
5
- from typing import Any, Dict, Optional, List
6
 
7
  class ToolBase:
8
  name: str = "tool"
9
  description: str = ""
10
-
11
  async def call(self, *args, **kwargs) -> Dict[str, Any]:
12
  raise NotImplementedError
13
 
 
14
  class OntologyTool(ToolBase):
15
  name = "ontology_normalize"
16
  description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)."
@@ -26,13 +25,14 @@ class OntologyTool(ToolBase):
26
  r = await self.http.get(
27
  "https://data.bioontology.org/search",
28
  params={"q": term, "pagesize": 5},
29
- headers={"Authorization": f"apikey token={self.bioportal_key}"}
30
  )
31
  out["bioportal"] = r.json()
32
  except Exception as e:
33
  out["bioportal_error"] = str(e)
34
  return out
35
 
 
36
  class PubMedTool(ToolBase):
37
  name = "pubmed_search"
38
  description = "Search PubMed via NCBI; return metadata with citations."
@@ -47,29 +47,31 @@ class PubMedTool(ToolBase):
47
  try:
48
  es = await self.http.get(
49
  base + "esearch.fcgi",
50
- params={"db":"pubmed","term":query,"retmode":"json","retmax":20,"api_key":self.key,"email":self.email}
51
  )
52
- ids = es.json().get("esearchresult",{}).get("idlist",[])
53
- if not ids: return {"query":query,"results":[]}
 
54
  su = await self.http.get(
55
  base + "esummary.fcgi",
56
- params={"db":"pubmed","id":",".join(ids),"retmode":"json","api_key":self.key,"email":self.email}
57
  )
58
- recs = su.json().get("result",{})
59
  items = []
60
  for pmid in ids:
61
- r = recs.get(pmid,{ })
62
  items.append({
63
  "pmid": pmid,
64
  "title": r.get("title"),
65
  "journal": r.get("fulljournalname"),
66
  "year": (r.get("pubdate") or "")[:4],
67
- "authors": [a.get("name") for a in r.get("authors",[])],
68
  })
69
- return {"query":query,"results":items}
70
  except Exception as e:
71
- return {"query":query,"error":str(e)}
72
 
 
73
  class StructureTool(ToolBase):
74
  name = "structure_info"
75
  description = "Query RCSB structure metadata (no lab steps)."
@@ -87,6 +89,7 @@ class StructureTool(ToolBase):
87
  out["error"] = str(e)
88
  return out
89
 
 
90
  class CrossrefTool(ToolBase):
91
  name = "crossref_search"
92
  description = "Crossref search for DOIs; titles, years, authors."
@@ -96,16 +99,53 @@ class CrossrefTool(ToolBase):
96
 
97
  async def call(self, query: str) -> dict:
98
  try:
99
- r = await self.http.get("https://api.crossref.org/works", params={"query":query,"rows":10})
100
- items = r.json().get("message",{}).get("items",[])
101
  papers = []
102
  for it in items:
103
  papers.append({
104
  "title": (it.get("title") or [None])[0],
105
  "doi": it.get("DOI"),
106
  "year": (it.get("issued") or {}).get("date-parts", [[None]])[0][0],
107
- "authors": [f"{a.get('given','')} {a.get('family','')}".strip() for a in it.get("author",[])],
108
  })
109
- return {"query":query,"results":papers}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  except Exception as e:
111
- return {"query":query,"error":str(e)}
 
 
1
  from __future__ import annotations
2
+ import os, json
3
  import httpx
4
+ from typing import Any, Dict, List
5
 
6
  class ToolBase:
7
  name: str = "tool"
8
  description: str = ""
 
9
  async def call(self, *args, **kwargs) -> Dict[str, Any]:
10
  raise NotImplementedError
11
 
12
+ # β€” Ontology normalization (BioPortal)
13
  class OntologyTool(ToolBase):
14
  name = "ontology_normalize"
15
  description = "Normalize biomedical terms via BioPortal; returns concept info (no protocols)."
 
25
  r = await self.http.get(
26
  "https://data.bioontology.org/search",
27
  params={"q": term, "pagesize": 5},
28
+ headers={"Authorization": f"apikey token={self.bioportal_key}"},
29
  )
30
  out["bioportal"] = r.json()
31
  except Exception as e:
32
  out["bioportal_error"] = str(e)
33
  return out
34
 
35
+ # β€” PubMed search (NCBI E-utilities)
36
  class PubMedTool(ToolBase):
37
  name = "pubmed_search"
38
  description = "Search PubMed via NCBI; return metadata with citations."
 
47
  try:
48
  es = await self.http.get(
49
  base + "esearch.fcgi",
50
+ params={"db":"pubmed","term":query,"retmode":"json","retmax":20,"api_key":self.key,"email":self.email},
51
  )
52
+ ids = es.json().get("esearchresult", {}).get("idlist", [])
53
+ if not ids:
54
+ return {"query": query, "results": []}
55
  su = await self.http.get(
56
  base + "esummary.fcgi",
57
+ params={"db":"pubmed","id":",".join(ids),"retmode":"json","api_key":self.key,"email":self.email},
58
  )
59
+ recs = su.json().get("result", {})
60
  items = []
61
  for pmid in ids:
62
+ r = recs.get(pmid, {})
63
  items.append({
64
  "pmid": pmid,
65
  "title": r.get("title"),
66
  "journal": r.get("fulljournalname"),
67
  "year": (r.get("pubdate") or "")[:4],
68
+ "authors": [a.get("name") for a in r.get("authors", [])],
69
  })
70
+ return {"query": query, "results": items}
71
  except Exception as e:
72
+ return {"query": query, "error": str(e)}
73
 
74
+ # β€” RCSB structure metadata
75
  class StructureTool(ToolBase):
76
  name = "structure_info"
77
  description = "Query RCSB structure metadata (no lab steps)."
 
89
  out["error"] = str(e)
90
  return out
91
 
92
+ # β€” Crossref DOIs
93
  class CrossrefTool(ToolBase):
94
  name = "crossref_search"
95
  description = "Crossref search for DOIs; titles, years, authors."
 
99
 
100
  async def call(self, query: str) -> dict:
101
  try:
102
+ r = await self.http.get("https://api.crossref.org/works", params={"query": query, "rows": 10})
103
+ items = r.json().get("message", {}).get("items", [])
104
  papers = []
105
  for it in items:
106
  papers.append({
107
  "title": (it.get("title") or [None])[0],
108
  "doi": it.get("DOI"),
109
  "year": (it.get("issued") or {}).get("date-parts", [[None]])[0][0],
110
+ "authors": [f"{a.get('given','')} {a.get('family','')}".strip() for a in it.get("author", [])],
111
  })
112
+ return {"query": query, "results": papers}
113
+ except Exception as e:
114
+ return {"query": query, "error": str(e)}
115
+
116
+ # β€” HF Inference API Reranker (optional)
117
+ class HFRerankTool(ToolBase):
118
+ name = "hf_rerank"
119
+ description = "Rerank documents using a Hugging Face reranker model (API)."
120
+
121
+ def __init__(self, model_id: str):
122
+ self.model = model_id
123
+ self.hf_token = os.getenv("HF_TOKEN")
124
+ self.http = httpx.AsyncClient(timeout=30.0)
125
+
126
+ async def call(self, query: str, documents: List[str]) -> dict:
127
+ if not self.hf_token:
128
+ return {"error": "HF_TOKEN not set"}
129
+ try:
130
+ # Generic payload; different models may expect different schemas β€” keep robust.
131
+ payload = {"inputs": {"query": query, "texts": documents}}
132
+ r = await self.http.post(
133
+ f"https://api-inference.huggingface.co/models/{self.model}",
134
+ headers={"Authorization": f"Bearer {self.hf_token}"},
135
+ json=payload,
136
+ )
137
+ data = r.json()
138
+ # Try to interpret scores
139
+ scores = []
140
+ if isinstance(data, dict) and "scores" in data:
141
+ scores = data["scores"]
142
+ elif isinstance(data, list) and data and isinstance(data[0], dict) and "score" in data[0]:
143
+ scores = [x.get("score", 0.0) for x in data]
144
+ else:
145
+ # Fallback: equal scores
146
+ scores = [1.0 for _ in documents]
147
+ # Sort indices by score desc
148
+ order = sorted(range(len(documents)), key=lambda i: scores[i], reverse=True)
149
+ return {"order": order, "scores": scores, "raw": data}
150
  except Exception as e:
151
+ return {"error": str(e)}