Spaces:
Sleeping
Sleeping
from fastembed import SparseTextEmbedding | |
from fastembed import TextEmbedding | |
from pydantic import BaseModel | |
from qdrant_client import QdrantClient | |
from qdrant_client import models | |
from settings import QDRANT_KEY | |
from settings import QDRANT_URL | |
from settings import SBERT_MODEL_NAME | |
from settings import SPARSE_MODEL_NAME | |
class HybridSearcher: | |
def __init__(self, collection_name): | |
self.collection_name = collection_name | |
self.client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY) | |
self.model_dense = TextEmbedding(model_name=SBERT_MODEL_NAME) | |
self.model_sparse = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME) | |
def search( | |
self, | |
documents: str | list[str], | |
limit: int = 10, | |
limit_dense: int = 2_000, | |
score_threshold_dense: float = 0.7, | |
limit_sparse: int = 1_000, | |
query_filter: None | BaseModel = None, | |
): | |
dense_embeddings = list(self.model_dense.embed(documents=documents))[0] | |
sparse_embeddings = list(self.model_sparse.embed(documents=documents))[0] | |
sparse_embeddings = models.SparseVector( | |
indices=sparse_embeddings.indices.tolist(), | |
values=sparse_embeddings.values.tolist(), | |
) | |
search_result = self.client.query_points( | |
collection_name=self.collection_name, | |
query=models.FusionQuery( | |
fusion=models.Fusion.RRF # we are using reciprocal rank fusion here | |
), | |
prefetch=[ | |
models.Prefetch( | |
query=dense_embeddings, | |
score_threshold=score_threshold_dense, | |
limit=limit_dense, | |
using="dense", | |
), | |
models.Prefetch( | |
query=sparse_embeddings, | |
limit=limit_sparse, | |
using="sparse", | |
), | |
], | |
query_filter=query_filter, # If you don't want any filters for now | |
limit=limit, # 5 the closest results | |
).points | |
return search_result | |
def build_year_filter(year_ge: int | None = None, year_le: int | None = None) -> models.Filter | None: | |
conditions = [] | |
if year_ge is not None: | |
conditions.append(models.FieldCondition( | |
key="year", | |
range=models.Range(gte=int(year_ge)) | |
)) | |
if year_le is not None: | |
conditions.append(models.FieldCondition( | |
key="year", | |
range=models.Range(lte=int(year_le)) | |
)) | |
if not conditions: | |
return None # no filtering | |
if len(conditions) == 1: | |
return models.Filter( | |
must=[conditions[0]] | |
) | |
# Both conditions: year >= ge AND year <= le | |
return models.Filter( | |
must=conditions | |
) | |