import concurrent.futures

import opik
from loguru import logger
from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, Range

from llm_engineering.application import utils
from llm_engineering.application.preprocessing.dispatchers import EmbeddingDispatcher
from llm_engineering.domain.embedded_chunks import (
    EmbeddedArticleChunk,
    EmbeddedChunk,
    EmbeddedPostChunk,
    EmbeddedRepositoryChunk,
)
from llm_engineering.domain.queries import EmbeddedQuery, Query
from llm_engineering.domain.video_chunks import EmbeddedVideoChunk

from .multimodal_dispatcher import MultimodalEmbeddingDispatcher
from .query_expansion import QueryExpansion
from .reranking import Reranker
from .self_query import SelfQuery


class ContextRetriever:
    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._metadata_extractor = SelfQuery(mock=mock)
        self._reranker = Reranker(mock=mock)

    @opik.track(name="ContextRetriever.search")
    def search(self, query: str | Query, k: int = 3, expand_to_n_queries: int = 3) -> list[EmbeddedChunk]:
        query_model = Query.from_str(query) if isinstance(query, str) else query
        query_model = self._metadata_extractor.generate(query_model)
        n_generated_queries = self._query_expander.generate(query_model, expand_to_n=expand_to_n_queries)

        # Run one vector search per expanded query in parallel; coerce any
        # None results to empty lists before flattening.
        n_k_documents = []
        if n_generated_queries:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                search_tasks = [
                    executor.submit(self._search, _query_model, k) for _query_model in n_generated_queries
                ]
                n_k_documents = [task.result() or [] for task in concurrent.futures.as_completed(search_tasks)]
        n_k_documents = utils.misc.flatten(n_k_documents)

        if n_k_documents:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        # Results are split evenly across the three data categories below,
        # hence the lower bound on k.
        assert k >= 3, "k should be >= 3"

        def _search_data_category(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            # Restrict the search to the extracted author, if one was found.
            if embedded_query.author_id:
                query_filter = Filter(
                    must=[
                        FieldCondition(
                            key="author_id",
                            match=MatchValue(value=str(embedded_query.author_id)),
                        )
                    ]
                )
            else:
                query_filter = None

            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k // 3,
                query_filter=query_filter,
            )

        embedded_query: EmbeddedQuery = EmbeddingDispatcher.dispatch(query)

        post_chunks = _search_data_category(EmbeddedPostChunk, embedded_query)
        articles_chunks = _search_data_category(EmbeddedArticleChunk, embedded_query)
        repositories_chunks = _search_data_category(EmbeddedRepositoryChunk, embedded_query)

        return post_chunks + articles_chunks + repositories_chunks

    def rerank(self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]:
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(query=query, chunks=chunks, keep_top_k=keep_top_k)
        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

class VideoContextRetriever(ContextRetriever):
    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        # Video chunks are assumed to live in a multimodal embedding space,
        # so the query is dispatched through MultimodalEmbeddingDispatcher
        # rather than the text-only EmbeddingDispatcher.
        embedded_query: EmbeddedQuery = MultimodalEmbeddingDispatcher.dispatch(query)

        def _search_video_category(embedded_query: EmbeddedQuery, k: int) -> list[EmbeddedChunk]:
            return EmbeddedVideoChunk.search(
                query_vector=embedded_query.embedding,
                limit=k,
                query_filter=self._create_time_filter(embedded_query),
            )

        return _search_video_category(embedded_query, k)
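    def _create_time_filter(self, embedded_query: EmbeddedQuery) -> Filter | None:
        # Hypothetical sketch: this helper is referenced above but never
        # defined. It assumes SelfQuery can attach an optional time window to
        # the query; the attribute names and the "start_time" payload key are
        # illustrative, not a confirmed part of the EmbeddedQuery API.
        start_time = getattr(embedded_query, "start_time", None)
        end_time = getattr(embedded_query, "end_time", None)
        if start_time is None and end_time is None:
            return None

        return Filter(
            must=[
                FieldCondition(
                    key="start_time",
                    range=Range(gte=start_time, lte=end_time),
                )
            ]
        )


if __name__ == "__main__":
    # Minimal usage sketch, assuming a populated Qdrant instance and embedding
    # models configured behind the dispatchers; the sample query is arbitrary.
    retriever = ContextRetriever(mock=False)
    documents = retriever.search("What is a vector database?", k=3, expand_to_n_queries=3)
    for document in documents:
        logger.info(document)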