import concurrent.futures

import opik
from loguru import logger
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range

from llm_engineering.application import utils
from llm_engineering.application.preprocessing.dispatchers import EmbeddingDispatcher
from llm_engineering.domain.embedded_chunks import (
    EmbeddedArticleChunk,
    EmbeddedChunk,
    EmbeddedPostChunk,
    EmbeddedRepositoryChunk,
)
from llm_engineering.domain.queries import EmbeddedQuery, Query
from llm_engineering.domain.video_chunks import EmbeddedVideoChunk

from .multimodal_dispatcher import MultimodalEmbeddingDispatcher
from .query_expansion import QueryExpansion
from .reranking import Reranker
from .self_query import SelfQuery


class ContextRetriever:
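    """Retrieves relevant document chunks from the vector DB for a user query.

    Pipeline: self-query metadata extraction -> query expansion -> parallel
    vector search per expanded query -> reranking of the merged results.

    A minimal usage sketch (the query string is illustrative):

        retriever = ContextRetriever(mock=False)
        documents = retriever.search("What is RAG?", k=3, expand_to_n_queries=3)
    """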
    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._metadata_extractor = SelfQuery(mock=mock)
        self._reranker = Reranker(mock=mock)

    @opik.track(name="ContextRetriever.search")
    def search(self, query: str | Query, k: int = 3, expand_to_n_queries: int = 3) -> list[EmbeddedChunk]:
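        """Retrieve the top `k` chunks for `query`.

        The query is enriched with self-extracted metadata, expanded into
        `expand_to_n_queries` variants, searched in parallel (one vector search
        per variant), and the merged results are reranked down to `k` documents.
        """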
        query_model = Query.from_str(query) if isinstance(query, str) else query

        # Extract metadata (e.g. the author) from the query to use as a search filter.
        query_model = self._metadata_extractor.generate(query_model)

        n_generated_queries = self._query_expander.generate(query_model, expand_to_n=expand_to_n_queries)

        # Run one vector search per expanded query in parallel; a task that
        # returns None contributes an empty list instead.
        n_k_documents = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            if n_generated_queries:
                search_tasks = [
                    executor.submit(self._search, _query_model, k) for _query_model in n_generated_queries
                ]
                n_k_documents = [task.result() or [] for task in concurrent.futures.as_completed(search_tasks)]

        n_k_documents = utils.misc.flatten(n_k_documents)

        if n_k_documents:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
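        """Embed the query and search every data category, `k // 3` results per category."""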
        assert k >= 3, "k should be >= 3 so that each of the 3 data categories gets at least one result"

        def _search_data_category(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            if embedded_query.author_id:
                query_filter = Filter(
                    must=[
                        FieldCondition(
                            key="author_id",
                            match=MatchValue(
                                value=str(embedded_query.author_id),
                            ),
                        )
                    ]
                )
            else:
                query_filter = None

            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k // 3,
                query_filter=query_filter,
            )

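        # Embed the (expanded) query with the same embedding setup used for the documents.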
        embedded_query: EmbeddedQuery = EmbeddingDispatcher.dispatch(query)

        post_chunks = _search_data_category(EmbeddedPostChunk, embedded_query)
        articles_chunks = _search_data_category(EmbeddedArticleChunk, embedded_query)
        repositories_chunks = _search_data_category(EmbeddedRepositoryChunk, embedded_query)

        retrieved_chunks = post_chunks + articles_chunks + repositories_chunks

        return retrieved_chunks

    def rerank(self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]:
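        """Rerank `chunks` against `query`, keeping only the `keep_top_k` best matches."""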
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(query=query, chunks=chunks, keep_top_k=keep_top_k)

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents


class VideoContextRetriever(ContextRetriever):
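    """Context retriever that searches embedded video chunks instead of the text categories."""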
    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        def _search_video_category(embedded_query: EmbeddedQuery, k: int) -> list[EmbeddedChunk]:
            # `self` is reachable through the closure, so it is not needed as a parameter.
            return EmbeddedVideoChunk.search(
                query_vector=embedded_query.embedding,
                limit=k,
                query_filter=self._create_time_filter(embedded_query),
            )
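
        # A minimal sketch of the remaining flow, mirroring ContextRetriever._search.
        # It assumes MultimodalEmbeddingDispatcher (imported above) exposes the same
        # `dispatch()` interface as EmbeddingDispatcher, and that `_create_time_filter`
        # is defined elsewhere in this class.
        embedded_query: EmbeddedQuery = MultimodalEmbeddingDispatcher.dispatch(query)

        return _search_video_category(embedded_query, k)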