CHI-tography / utils_semantic_search.py
ocantocarlos's picture
feat: add files for hf space deployment
9e2a8ba
from fastembed import SparseTextEmbedding
from fastembed import TextEmbedding
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client import models
from settings import QDRANT_KEY
from settings import QDRANT_URL
from settings import SBERT_MODEL_NAME
from settings import SPARSE_MODEL_NAME
class HybridSearcher:
    """Hybrid (dense + sparse) searcher over a single Qdrant collection.

    A query is embedded with both a dense and a sparse fastembed model; each
    embedding drives its own prefetch against the collection, and the two
    candidate lists are merged with reciprocal rank fusion (RRF).
    """

    def __init__(self, collection_name):
        """Connect to Qdrant and load both embedding models.

        Args:
            collection_name: Name of the collection that ``search`` queries.
                The collection is addressed through the named vectors
                ``"dense"`` and ``"sparse"``.
        """
        self.collection_name = collection_name
        # Connection details and model names come from the project settings.
        self.client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)
        self.model_dense = TextEmbedding(model_name=SBERT_MODEL_NAME)
        self.model_sparse = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

    def search(
        self,
        documents: str | list[str],
        limit: int = 10,
        limit_dense: int = 2_000,
        score_threshold_dense: float = 0.7,
        limit_sparse: int = 1_000,
        query_filter: None | BaseModel = None,
    ):
        """Run a hybrid query and return the fused result points.

        Args:
            documents: Query text. When several texts are given, all of them
                are embedded but only the first embedding is used as the query.
            limit: Maximum number of points requested from the fused query.
            limit_dense: Maximum number of candidates in the dense prefetch.
            score_threshold_dense: Score threshold applied to the dense
                prefetch only.
            limit_sparse: Maximum number of candidates in the sparse prefetch.
            query_filter: Filter forwarded as ``query_filter`` of the
                top-level query; it is not set on the individual prefetches.

        Returns:
            The ``points`` attribute of the ``query_points`` response.

        Raises:
            IndexError: If an embedding model yields no embedding for
                ``documents`` (for example an empty list).
        """
        # Each model yields one embedding per input text; keep the first one.
        dense_query = list(self.model_dense.embed(documents=documents))[0]

        raw_sparse = list(self.model_sparse.embed(documents=documents))[0]
        # Convert the embedding's arrays to plain lists for the Qdrant model.
        sparse_query = models.SparseVector(
            indices=raw_sparse.indices.tolist(),
            values=raw_sparse.values.tolist(),
        )

        dense_branch = models.Prefetch(
            query=dense_query,
            score_threshold=score_threshold_dense,
            limit=limit_dense,
            using="dense",
        )
        sparse_branch = models.Prefetch(
            query=sparse_query,
            limit=limit_sparse,
            using="sparse",
        )

        response = self.client.query_points(
            collection_name=self.collection_name,
            # Merge both candidate lists with reciprocal rank fusion.
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            prefetch=[dense_branch, sparse_branch],
            query_filter=query_filter,
            limit=limit,
        )
        return response.points
def build_year_filter(
    year_ge: int | None = None,
    year_le: int | None = None,
) -> models.Filter | None:
    """Build a Qdrant filter restricting the ``year`` payload field.

    Args:
        year_ge: Inclusive lower bound (``year >= year_ge``); ``None`` leaves
            the lower side open.
        year_le: Inclusive upper bound (``year <= year_le``); ``None`` leaves
            the upper side open.

    Returns:
        ``None`` when neither bound is given (no filtering). Otherwise a
        ``models.Filter`` whose ``must`` list holds one range condition per
        supplied bound, lower bound first, so both bounds must hold together.
    """
    # Pair each Range keyword with its bound; order fixes the condition order.
    bounds = (("gte", year_ge), ("lte", year_le))
    year_conditions = [
        models.FieldCondition(
            key="year",
            range=models.Range(**{operator: int(bound)}),
        )
        for operator, bound in bounds
        if bound is not None
    ]

    if not year_conditions:
        return None  # no filtering

    return models.Filter(must=year_conditions)