import os

# Tell HF Hub to use /tmp for cache. This must run before importing fastembed:
# huggingface_hub resolves its cache paths at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

from fastembed import SparseEmbedding
from fastembed import SparseTextEmbedding
from fastembed import TextEmbedding
from qdrant_client import QdrantClient
from qdrant_client import models

from models_semantic_search import VectorType
from settings import MIN_COSINE_SCORE
from settings import QDRANT_KEY
from settings import QDRANT_URL
from settings import SBERT_MODEL_NAME
from settings import SPARSE_MODEL_NAME

class HybridSearcher:
    def __init__(self, collection_name):
        self.collection_name = collection_name
        self.client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)
        self.model_dense = TextEmbedding(model_name=SBERT_MODEL_NAME)
        self.model_sparse = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

    def embed_dense(self, documents: str | list[str]):
        # embed() yields one vector per document; return the first one
        return list(self.model_dense.embed(documents=documents))[0]

    def embed_sparse(self, documents: str | list[str]) -> SparseEmbedding:
        # embed() yields one sparse vector per document; return the first one
        return list(self.model_sparse.embed(documents=documents))[0]

    def embed(self, documents: str | list[str], vector_type: VectorType):
        # Dispatch to the dense or sparse embedder based on the requested type
        match vector_type:
            case VectorType.dense:
                return self.embed_dense(documents=documents)
            case VectorType.sparse:
                return self.embed_sparse(documents=documents)
            case _:
                raise ValueError(f"invalid embedding type {vector_type}")

    def search(
        self,
        documents: str | list[str],
        limit: int = 10,
        limit_dense: int = 2_000,
        score_threshold_dense: float = MIN_COSINE_SCORE,
        limit_sparse: int = 1_000,
        query_filter: models.Filter | None = None,
    ):
        dense_embeddings = self.embed_dense(documents=documents)
        sparse_embeddings = self.embed_sparse(documents=documents)
        sparse_embeddings = models.SparseVector(
            indices=sparse_embeddings.indices.tolist(),
            values=sparse_embeddings.values.tolist(),
        )
        search_result = self.client.query_points(
            collection_name=self.collection_name,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF  # merge prefetch results with reciprocal rank fusion
            ),
            prefetch=[
                models.Prefetch(
                    query=dense_embeddings,
                    score_threshold=score_threshold_dense,
                    limit=limit_dense,
                    using="dense",
                ),
                models.Prefetch(
                    query=sparse_embeddings,
                    limit=limit_sparse,
                    using="sparse",
                ),
            ],
            query_filter=query_filter,  # None applies no payload filter
            limit=limit,  # number of fused results to return
        ).points
        return search_result

def build_year_filter(year_ge: int | None = None, year_le: int | None = None) -> models.Filter | None:
    # Build a payload filter on the "year" field; both bounds are inclusive.
    conditions = []
    if year_ge is not None:
        conditions.append(models.FieldCondition(
            key="year",
            range=models.Range(gte=int(year_ge)),
        ))
    if year_le is not None:
        conditions.append(models.FieldCondition(
            key="year",
            range=models.Range(lte=int(year_le)),
        ))
    if not conditions:
        return None  # no filtering
    # One or both conditions: year >= year_ge AND year <= year_le
    return models.Filter(must=conditions)
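

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The collection
    # name "papers" and the year bounds are placeholder assumptions; the
    # collection must have named vectors "dense" and "sparse" plus an integer
    # "year" payload field for this to run.
    searcher = HybridSearcher(collection_name="papers")
    year_filter = build_year_filter(year_ge=2020, year_le=2024)
    points = searcher.search(
        documents="hybrid search with reciprocal rank fusion",
        limit=5,
        query_filter=year_filter,
    )
    for point in points:
        print(point.id, point.score)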