"""CPU-optimized retrieval for efficient context handling."""

import heapq
import logging
from typing import List, Optional

import numpy as np

from efficient_context.chunking.base import Chunk
from efficient_context.retrieval.base import BaseRetriever

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CPUOptimizedRetriever(BaseRetriever):
    """Retriever optimized for CPU performance and low memory usage.

    This retriever uses techniques that minimize computational requirements
    while still providing high-quality retrieval results.
    """

    def __init__(
        self,
        embedding_model: str = "lightweight",
        similarity_metric: str = "cosine",
        use_batching: bool = True,
        batch_size: int = 32,
        max_index_size: Optional[int] = None,
    ):
        """Initialize the CPUOptimizedRetriever.

        Args:
            embedding_model: Model to use for embeddings ("lightweight" or a
                sentence-transformers model name)
            similarity_metric: Metric for comparing embeddings ("cosine",
                "dot", or "euclidean")
            use_batching: Whether to batch embedding operations
            batch_size: Size of batches for embedding
            max_index_size: Maximum number of chunks to keep in the index
        """
        self.embedding_model = embedding_model
        self.similarity_metric = similarity_metric
        self.use_batching = use_batching
        self.batch_size = batch_size
        self.max_index_size = max_index_size

        # Index state: chunks, their embeddings, and a chunk-id lookup table.
        self.chunks = []
        self.chunk_embeddings = None
        self.chunk_ids_to_index = {}
        self._vectorizer = None  # TF-IDF fallback vectorizer, fitted lazily.

        self._init_embedding_model()

        logger.info("CPUOptimizedRetriever initialized with model: %s", embedding_model)

    def _init_embedding_model(self):
        """Initialize the embedding model."""
        try:
            from sentence_transformers import SentenceTransformer

            if self.embedding_model == "lightweight":
                # A compact 3-layer MiniLM model keeps CPU inference fast.
                self.model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
            else:
                self.model = SentenceTransformer(self.embedding_model)

            logger.info(
                "Using embedding model with dimension: %d",
                self.model.get_sentence_embedding_dimension(),
            )
        except ImportError:
            logger.warning("SentenceTransformer not available, using TF-IDF fallback (less accurate)")
            self.model = None

    def _get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Get embeddings for a list of texts.

        Args:
            texts: List of texts to embed

        Returns:
            embeddings: Array of text embeddings
        """
        if not texts:
            return np.array([])

        if self.model is not None:
            if self.use_batching and len(texts) > self.batch_size:
                # Encode in fixed-size batches to bound peak memory usage.
                embeddings = []
                for i in range(0, len(texts), self.batch_size):
                    batch = texts[i:i + self.batch_size]
                    batch_embeddings = self.model.encode(
                        batch,
                        show_progress_bar=False,
                        convert_to_numpy=True,
                    )
                    embeddings.append(batch_embeddings)
                return np.vstack(embeddings)
            return self.model.encode(texts, show_progress_bar=False)

        # TF-IDF fallback: fit the vectorizer once (on the indexed chunks) and
        # reuse it afterwards, so query vectors and chunk vectors share the
        # same feature space and dimensionality.
        from sklearn.feature_extraction.text import TfidfVectorizer

        if self._vectorizer is None:
            self._vectorizer = TfidfVectorizer(max_features=5000)
            return self._vectorizer.fit_transform(texts).toarray()
        return self._vectorizer.transform(texts).toarray()

    def _compute_similarities(self, query_embedding: np.ndarray, chunk_embeddings: np.ndarray) -> np.ndarray:
        """Compute similarities between the query and chunk embeddings.

        Args:
            query_embedding: Embedding of the query
            chunk_embeddings: Embeddings of the chunks

        Returns:
            similarities: Array of similarity scores (higher is more similar)
        """
        if self.similarity_metric == "cosine":
            # Chunk embeddings are normalized at index time, so normalizing
            # the query reduces cosine similarity to a single dot product.
            query_norm = np.linalg.norm(query_embedding)
            if query_norm > 0:
                query_embedding = query_embedding / query_norm
            return np.dot(chunk_embeddings, query_embedding)
        if self.similarity_metric == "dot":
            return np.dot(chunk_embeddings, query_embedding)
        if self.similarity_metric == "euclidean":
            # Negate the distance so that larger scores are always better.
            return -np.sqrt(np.sum((chunk_embeddings - query_embedding) ** 2, axis=1))
        # Unknown metric: fall back to a plain dot product.
        return np.dot(chunk_embeddings, query_embedding)

    def index_chunks(self, chunks: List[Chunk]) -> None:
        """Index chunks for future retrieval.

        Args:
            chunks: Chunks to index
        """
        if not chunks:
            return

        # Append only chunks that are not already indexed.
        for chunk in chunks:
            if chunk.chunk_id in self.chunk_ids_to_index:
                continue
            self.chunks.append(chunk)
            self.chunk_ids_to_index[chunk.chunk_id] = len(self.chunks) - 1

        # Re-embed the full index so embeddings stay aligned with self.chunks.
        chunk_texts = [chunk.content for chunk in self.chunks]
        self.chunk_embeddings = self._get_embeddings(chunk_texts)

        # Evict the oldest chunks when the index exceeds its size cap, and
        # slice the existing embeddings instead of re-embedding everything.
        if self.max_index_size is not None and len(self.chunks) > self.max_index_size:
            self.chunks = self.chunks[-self.max_index_size:]
            self.chunk_ids_to_index = {
                chunk.chunk_id: i for i, chunk in enumerate(self.chunks)
            }
            self.chunk_embeddings = self.chunk_embeddings[-self.max_index_size:]

        # Pre-normalize chunk embeddings so cosine retrieval reduces to a
        # dot product at query time; guard against zero-norm rows.
        if self.similarity_metric == "cosine" and self.chunk_embeddings is not None:
            norms = np.linalg.norm(self.chunk_embeddings, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            self.chunk_embeddings = self.chunk_embeddings / norms

        logger.info("Indexed %d chunks (total: %d)", len(chunks), len(self.chunks))

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Chunk]:
        """Retrieve chunks relevant to a query.

        Args:
            query: Query to retrieve chunks for
            top_k: Number of chunks to retrieve (default: 5)

        Returns:
            chunks: List of retrieved chunks, most similar first
        """
        if not self.chunks:
            logger.warning("No chunks indexed for retrieval")
            return []

        if not query:
            logger.warning("Empty query provided")
            return []

        top_k = 5 if top_k is None else top_k

        query_embedding = self._get_embeddings([query])[0]
        similarities = self._compute_similarities(query_embedding, self.chunk_embeddings)

        if top_k >= len(similarities):
            # Requesting everything: a full sort is fine.
            top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
        else:
            # heapq.nlargest runs in O(n log k), cheaper than a full
            # O(n log n) sort when only a few results are needed.
            top_indices = heapq.nlargest(top_k, range(len(similarities)), key=lambda i: similarities[i])

        retrieved_chunks = [self.chunks[i] for i in top_indices]

        logger.info("Retrieved %d chunks for query", len(retrieved_chunks))
        return retrieved_chunks

    def clear(self) -> None:
        """Clear all indexed chunks."""
        self.chunks = []
        self.chunk_embeddings = None
        self.chunk_ids_to_index = {}
        self._vectorizer = None  # Refit the TF-IDF fallback on the next index.
        logger.info("Cleared chunk index")
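

# Minimal usage sketch (illustrative only, not part of the library's API).
# It assumes Chunk can be constructed as Chunk(content=..., chunk_id=...);
# adjust to the actual constructor in efficient_context.chunking.base if it
# differs.
if __name__ == "__main__":
    retriever = CPUOptimizedRetriever(max_index_size=100)
    retriever.index_chunks([
        Chunk(content="CPUs handle small retrieval workloads well.", chunk_id="c1"),
        Chunk(content="GPUs accelerate large embedding models.", chunk_id="c2"),
    ])
    for chunk in retriever.retrieve("retrieval on CPU hardware", top_k=1):
        print(chunk.chunk_id, chunk.content)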