File size: 2,191 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import functools
from typing import Optional, Union, Any
import requests
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory
from dsp.utils import dotdict
# TODO: Ideally, this takes the name of the index and looks up its port.
class ColBERTv2:
    """Callable client that retrieves passages from a ColBERTv2 server."""

    def __init__(
        self,
        url: str = "http://0.0.0.0",
        port: Optional[Union[str, int]] = None,
        post_requests: bool = False,
    ):
        """Store the server endpoint and the request mode.

        Args:
            url: Base URL of the server.
            port: Optional port; when truthy it is appended to ``url`` as
                ``"{url}:{port}"``, otherwise ``url`` is used unchanged.
            post_requests: When true, queries go through
                ``colbertv2_post_request``; otherwise through
                ``colbertv2_get_request``.
        """
        if port:
            url = f"{url}:{port}"
        self.url = url
        self.post_requests = post_requests

    def __call__(
        self, query: str, k: int = 10, simplify: bool = False
    ) -> Union[list[str], list[dotdict]]:
        """Retrieve the top-``k`` passages for ``query``.

        Args:
            query: Search query forwarded to the server.
            k: Number of passages to request.
            simplify: When true, return only each passage's ``"long_text"``
                value; otherwise wrap every passage in a ``dotdict``.

        Returns:
            A list of ``"long_text"`` strings when ``simplify`` is true,
            otherwise a list of ``dotdict`` objects, one per passage.
        """
        # Select the request helper once, based on the configured mode.
        fetch = colbertv2_post_request if self.post_requests else colbertv2_get_request
        passages: list[dict[str, Any]] = fetch(self.url, query, k)

        if not simplify:
            return [dotdict(passage) for passage in passages]
        return [passage["long_text"] for passage in passages]
@CacheMemory.cache
def colbertv2_get_request_v2(url: str, query: str, k: int):
    """Query a ColBERTv2 server with an HTTP GET and return its top passages.

    Args:
        url: Endpoint of the ColBERTv2 server.
        query: Search query, sent as the ``query`` URL parameter.
        k: Number of passages to request, sent as the ``k`` URL parameter.
            Must not exceed 100.

    Returns:
        At most ``k`` entries taken from the ``"topk"`` field of the JSON
        response. Each entry is a copy of the server's dict with a
        ``"long_text"`` key set to the entry's ``"text"`` value.

    Raises:
        AssertionError: If ``k`` is greater than 100.
    """
    # Raised explicitly rather than via `assert` so the limit is still
    # enforced when Python runs with -O (which strips assert statements).
    # AssertionError and its message are kept so existing callers that
    # relied on the old behavior are unaffected.
    if k > 100:
        raise AssertionError(
            "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
        )

    payload = {"query": query, "k": k}
    # timeout=10 bounds how long requests waits on the server (seconds).
    res = requests.get(url, params=payload, timeout=10)

    # A single slice caps the result at k entries; the comprehension keeps
    # that length, so no second truncation is needed.
    return [{**d, "long_text": d["text"]} for d in res.json()["topk"][:k]]
@functools.cache
@NotebookCacheMemory.cache
def colbertv2_get_request_v2_wrapped(*args, **kwargs):
    """Forward all arguments unchanged to ``colbertv2_get_request_v2``.

    The function is decorated with ``NotebookCacheMemory.cache`` and then
    with an unbounded in-process ``functools`` cache, so arguments must be
    hashable.
    """
    return colbertv2_get_request_v2(*args, **kwargs)


# Name under which the GET helper is called by the rest of the module.
colbertv2_get_request = colbertv2_get_request_v2_wrapped
@CacheMemory.cache
def colbertv2_post_request_v2(url: str, query: str, k: int):
    """Query a ColBERTv2 server with an HTTP POST and return its top passages.

    Args:
        url: Endpoint of the ColBERTv2 server.
        query: Search query, sent as the ``query`` field of the JSON body.
        k: Number of passages to request, sent as the ``k`` field of the
            JSON body.

    Returns:
        At most ``k`` entries taken from the ``"topk"`` field of the JSON
        response. An entry that has a ``"text"`` key but no ``"long_text"``
        key is copied with ``"long_text"`` set to its ``"text"`` value;
        every other entry is returned as received.
    """
    headers = {"Content-Type": "application/json; charset=utf-8"}
    payload = {"query": query, "k": k}
    # timeout=10 bounds how long requests waits on the server (seconds).
    res = requests.post(url, json=payload, headers=headers, timeout=10)
    passages = res.json()["topk"][:k]

    # ColBERTv2.__call__(simplify=True) reads psg["long_text"], and the GET
    # helper always provides that key. Mirror it here, but only fill the key
    # when it is absent so values already sent by the server are preserved.
    return [
        {**psg, "long_text": psg["text"]}
        if "long_text" not in psg and "text" in psg
        else psg
        for psg in passages
    ]
@functools.cache
@NotebookCacheMemory.cache
def colbertv2_post_request_v2_wrapped(*args, **kwargs):
    """Forward all arguments unchanged to ``colbertv2_post_request_v2``.

    The function is decorated with ``NotebookCacheMemory.cache`` and then
    with an unbounded in-process ``functools`` cache, so arguments must be
    hashable.
    """
    return colbertv2_post_request_v2(*args, **kwargs)


# Name under which the POST helper is called by the rest of the module.
colbertv2_post_request = colbertv2_post_request_v2_wrapped
|