import numpy as np
import base64
import json
from typing import List, Dict, Any
from dataclasses import dataclass
try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False

try:
    import faiss
    HAS_FAISS = True
except ImportError:
    HAS_FAISS = False


@dataclass
class SearchResult:
    chunk_id: str
    text: str
    score: float
    metadata: Dict[str, Any]


class VectorStore:
    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.embedding_model_name = embedding_model
        self.embedding_model = None
        self.index = None
        self.chunks = {}     # chunk_id -> chunk data
        self.chunk_ids = []  # Ordered list mapping FAISS row index -> chunk_id
        self.dimension = 384  # Default for all-MiniLM-L6-v2
        if HAS_SENTENCE_TRANSFORMERS:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the embedding model"""
        if not HAS_SENTENCE_TRANSFORMERS:
            raise ImportError("sentence-transformers not installed")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)
        # Update dimension based on the loaded model
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()

    def create_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Create embeddings for a list of texts"""
        if not self.embedding_model:
            self._initialize_model()
        # Process in batches for efficiency
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embedding_model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False
            )
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings) if embeddings else np.array([])

    def build_index(self, chunks: List[Dict[str, Any]], show_progress: bool = True):
        """Build FAISS index from chunks"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")
        # Extract texts and build embeddings
        texts = [chunk['text'] for chunk in chunks]
        if show_progress:
            print(f"Creating embeddings for {len(texts)} chunks...")
        embeddings = self.create_embeddings(texts)
        # Build FAISS index
        if show_progress:
            print("Building FAISS index...")
        # Use IndexFlatIP for inner product (cosine similarity with normalized vectors)
        self.index = faiss.IndexFlatIP(self.dimension)
        # Normalize embeddings in place so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)
        # Store chunks and maintain the id -> position mapping
        self.chunks = {}
        self.chunk_ids = []
        for chunk in chunks:
            chunk_id = chunk['chunk_id']
            self.chunks[chunk_id] = chunk
            self.chunk_ids.append(chunk_id)
        if show_progress:
            print(f"Index built with {len(chunks)} chunks")

    def search(self, query: str, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search for similar chunks"""
        if not self.index or not self.chunks:
            return []
        # Create the query embedding and normalize it for cosine similarity
        query_embedding = self.create_embeddings([query])
        faiss.normalize_L2(query_embedding)
        # Search the index
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))
        # Convert raw FAISS hits to SearchResult objects
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            result = SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            )
            results.append(result)
        return results

    def serialize(self) -> Dict[str, Any]:
        """Serialize the vector store for deployment"""
        if not self.index:
            raise ValueError("No index to serialize")
        # faiss.serialize_index returns a uint8 numpy array; convert to raw
        # bytes before base64-encoding so the payload is JSON-friendly
        index_bytes = faiss.serialize_index(self.index).tobytes()
        index_base64 = base64.b64encode(index_bytes).decode('utf-8')
        return {
            'index_base64': index_base64,
            'chunks': self.chunks,
            'chunk_ids': self.chunk_ids,
            'dimension': self.dimension,
            'model_name': self.embedding_model_name
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> 'VectorStore':
        """Deserialize a vector store from deployment data"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")
        store = cls(embedding_model=data['model_name'])
        # faiss.deserialize_index expects a uint8 numpy array, not raw bytes
        index_bytes = base64.b64decode(data['index_base64'])
        store.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
        # Restore chunks and mappings
        store.chunks = data['chunks']
        store.chunk_ids = data['chunk_ids']
        store.dimension = data['dimension']
        return store

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        return {
            'total_chunks': len(self.chunks),
            'index_size': self.index.ntotal if self.index else 0,
            'dimension': self.dimension,
            'model': self.embedding_model_name
        }
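

# Hedged usage sketch, not part of the original module: one way a build script
# might create, query, and persist a VectorStore. The chunk dicts and the
# "store.json" path are illustrative assumptions; only the 'chunk_id', 'text',
# and 'metadata' keys are what build_index() and search() actually read.
def _example_build_and_persist():
    chunks = [
        {'chunk_id': 'c1', 'text': 'FAISS is a library for similarity search.',
         'metadata': {'source': 'doc1'}},
        {'chunk_id': 'c2', 'text': 'Sentence embeddings map text to vectors.',
         'metadata': {'source': 'doc2'}},
    ]
    store = VectorStore()
    store.build_index(chunks)
    for result in store.search("vector similarity search", top_k=2, score_threshold=0.0):
        print(result.chunk_id, round(result.score, 3), result.text)
    # serialize() returns JSON-friendly data as long as chunk metadata is
    # JSON-serializable, so a plain json.dump suffices for deployment
    with open('store.json', 'w') as f:
        json.dump(store.serialize(), f)
    with open('store.json') as f:
        restored = VectorStore.deserialize(json.load(f))
    print(restored.get_stats())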


class LightweightVectorStore:
    """Lightweight version for deployed spaces without the embedding model"""

    def __init__(self, serialized_data: Dict[str, Any]):
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")
        # Deserialize the FAISS index (deserialize_index expects a uint8 array)
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']
        self.dimension = serialized_data['dimension']
        # Query embeddings must come from elsewhere: either pre-computed and
        # cached here, or produced by a lightweight embedding service
        self.query_embeddings_cache = serialized_data.get('query_embeddings_cache', {})

    def search_with_embedding(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search using a pre-computed query embedding"""
        if not self.index or not self.chunks:
            return []
        # Ensure shape (1, dimension) float32, as FAISS expects
        query_embedding = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        # Normalize (in place) for cosine similarity
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))
        # Convert raw FAISS hits to SearchResult objects
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            result = SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            )
            results.append(result)
        return results
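

# Hedged sketch (an addition, not original code): a deployed space that holds
# only a LightweightVectorStore still needs query vectors from somewhere. One
# option is to compute them with the same SentenceTransformer model wherever
# sentence-transformers is available; embed_query below assumes exactly that.
def embed_query(query: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> np.ndarray:
    if not HAS_SENTENCE_TRANSFORMERS:
        raise ImportError("sentence-transformers not installed")
    model = SentenceTransformer(model_name)
    # Returns shape (1, dimension) float32, ready for search_with_embedding
    return model.encode([query], convert_to_numpy=True).astype(np.float32)

# Usage: results = LightweightVectorStore(data).search_with_embedding(embed_query("my question"))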


# Utility functions

def estimate_index_size(num_chunks: int, dimension: int = 384) -> float:
    """Estimate the size of the index in MB"""
    # Rough estimate: 4 bytes per float32 * dimension * num_chunks
    bytes_size = 4 * dimension * num_chunks
    # Add ~20% overhead for index structure and metadata
    overhead = 1.2
    return (bytes_size * overhead) / (1024 * 1024)
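

# Illustrative, minimal check of the estimator above; the figures are
# back-of-envelope only, per the comment in estimate_index_size
if __name__ == "__main__":
    for n in (1_000, 10_000, 100_000):
        print(f"{n} chunks -> ~{estimate_index_size(n):.1f} MB")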