File size: 3,754 Bytes
21549be
 
d2b1491
9e2a8ba
 
 
 
 
 
21549be
 
9e2a8ba
 
 
 
1ab172b
 
 
 
 
9e2a8ba
 
 
 
 
 
 
 
d2b1491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e2a8ba
 
 
 
 
21549be
9e2a8ba
 
 
d2b1491
 
9e2a8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os

from fastembed import SparseEmbedding
from fastembed import SparseTextEmbedding
from fastembed import TextEmbedding
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client import models

from models_semantic_search import VectorType
from settings import MIN_COSINE_SCORE
from settings import QDRANT_KEY
from settings import QDRANT_URL
from settings import SBERT_MODEL_NAME
from settings import SPARSE_MODEL_NAME

# Tell HF Hub to use /tmp for cache — /tmp is typically the only writable
# path in serverless / read-only container environments, and model downloads
# would otherwise fail at import time when the embedding models load.
os.environ["HF_HOME"] = "/tmp/huggingface"
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — presumably kept here for compatibility
# with older library versions; confirm before removing.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"


class HybridSearcher:
    def __init__(self, collection_name):
        self.collection_name = collection_name
        self.client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)
        self.model_dense = TextEmbedding(model_name=SBERT_MODEL_NAME)
        self.model_sparse = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

    def embed_dense(self, documents: str | list[str]):
        return list(self.model_dense.embed(documents=documents))[0]

    def embed_sparse(self, documents: str | list[str]) -> SparseEmbedding:
        return list(self.model_sparse.embed(documents=documents))[0]

    def embed(self, documents: str | list[str], vector_type: VectorType):
        match vector_type:
            case VectorType.dense:
                return self.embed_dense(documents=documents)
            case VectorType.sparse:
                return self.embed_sparse(documents=documents)
            case _:
                raise ValueError(f"invalid embedding type {vector_type}")

    def search(
            self,
            documents: str | list[str],
            limit: int = 10,
            limit_dense: int = 2_000,
            score_threshold_dense: float = MIN_COSINE_SCORE,
            limit_sparse: int = 1_000,
            query_filter: None | BaseModel = None,
    ):
        dense_embeddings = self.embed_dense(documents=documents)
        sparse_embeddings = self.embed_sparse(documents=documents)
        sparse_embeddings = models.SparseVector(
            indices=sparse_embeddings.indices.tolist(),
            values=sparse_embeddings.values.tolist(),
        )

        search_result = self.client.query_points(
            collection_name=self.collection_name,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF  # we are using reciprocal rank fusion here
            ),
            prefetch=[
                models.Prefetch(
                    query=dense_embeddings,
                    score_threshold=score_threshold_dense,
                    limit=limit_dense,
                    using="dense",

                ),
                models.Prefetch(
                    query=sparse_embeddings,
                    limit=limit_sparse,
                    using="sparse",
                ),
            ],
            query_filter=query_filter,  # If you don't want any filters for now
            limit=limit,  # 5 the closest results
        ).points

        return search_result


def build_year_filter(year_ge: int | None = None, year_le: int | None = None) -> models.Filter | None:
    """Build a Qdrant payload filter restricting the ``year`` field.

    Args:
        year_ge: Inclusive lower bound on year, or ``None`` for no lower bound.
        year_le: Inclusive upper bound on year, or ``None`` for no upper bound.

    Returns:
        A ``models.Filter`` with a single range condition on ``year``, or
        ``None`` when both bounds are absent (meaning: no filtering).
    """
    if year_ge is None and year_le is None:
        return None  # no filtering

    # A single Range carries both bounds at once (unset bounds stay None and
    # are ignored by Qdrant), so one FieldCondition covers every case —
    # no need for separate gte/lte conditions on the same key.
    year_range = models.Range(
        gte=int(year_ge) if year_ge is not None else None,
        lte=int(year_le) if year_le is not None else None,
    )
    return models.Filter(
        must=[models.FieldCondition(key="year", range=year_range)]
    )