import base64
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np

try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False

try:
    import faiss
    HAS_FAISS = True
except ImportError:
    HAS_FAISS = False


@dataclass
class SearchResult:
    """A single retrieval result returned by a similarity search."""
    chunk_id: str
    text: str
    score: float
    metadata: Dict[str, Any]


class VectorStore:
    """Builds and searches a FAISS index over sentence-transformer embeddings."""

    def __init__(self, embedding_model: str = "all-MiniLM-L6-v2"):
        self.embedding_model_name = embedding_model
        self.embedding_model = None
        self.index = None
        self.chunks = {}      # chunk_id -> chunk data
        self.chunk_ids = []   # Ordered list for FAISS index mapping
        self.dimension = 384  # Default for all-MiniLM-L6-v2

        if HAS_SENTENCE_TRANSFORMERS:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the embedding model"""
        if not HAS_SENTENCE_TRANSFORMERS:
            raise ImportError("sentence-transformers not installed")

        try:
            print(f"Loading embedding model: {self.embedding_model_name}")
            print("This may take a moment on first run as the model downloads...")

            # Set environment variables to prevent multiprocessing issues
            import os
            os.environ['TOKENIZERS_PARALLELISM'] = 'false'
            os.environ['OMP_NUM_THREADS'] = '1'
            os.environ['MKL_NUM_THREADS'] = '1'

            # Initialize with settings that avoid multiprocessing issues
            self.embedding_model = SentenceTransformer(
                self.embedding_model_name,
                device='cpu',            # Force CPU to avoid GPU/multiprocessing conflicts
                cache_folder=None,       # Use default cache
                use_auth_token=False,    # Public model, no auth token needed
                trust_remote_code=False  # Security best practice
            )

            # Disable multiprocessing for stability in web apps
            if hasattr(self.embedding_model, 'pool'):
                self.embedding_model.pool = None

            # Additional stability measures for the Gradio environment
            if hasattr(self.embedding_model, '_modules'):
                for module in self.embedding_model._modules.values():
                    if hasattr(module, 'num_workers'):
                        module.num_workers = 0

            # Update dimension based on the loaded model
            self.dimension = self.embedding_model.get_sentence_embedding_dimension()
            print(f"Model loaded successfully, dimension: {self.dimension}")
        except Exception as e:
            print(f"Failed to initialize embedding model: {e}")
            # Provide more specific error messages
            error_text = str(e).lower()
            if "connection" in error_text or "timeout" in error_text:
                raise RuntimeError(
                    f"Network error downloading model '{self.embedding_model_name}'. "
                    f"Please check your internet connection and try again: {e}"
                )
            elif "memory" in error_text:
                raise RuntimeError(
                    f"Insufficient memory to load model '{self.embedding_model_name}'. "
                    f"Try using a smaller model or increase available memory: {e}"
                )
            else:
                raise RuntimeError(
                    f"Could not load embedding model '{self.embedding_model_name}': {e}"
                )

    def create_embeddings(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
        """Create embeddings for a list of texts"""
        if not self.embedding_model:
            self._initialize_model()

        import gc

        # Use a small batch size for stability
        embeddings = []
        try:
            print(f"Creating embeddings for {len(texts)} text chunks...")
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                print(f"Processing batch {i // batch_size + 1}/{(len(texts) + batch_size - 1) // batch_size}")

                batch_embeddings = self.embedding_model.encode(
                    batch,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device='cpu',                  # Force CPU to avoid GPU conflicts
                    normalize_embeddings=False,    # We'll normalize later with FAISS
                    batch_size=min(batch_size, 4)  # Extra safety on batch size
                )
                embeddings.append(batch_embeddings)

                gc.collect()  # Force garbage collection between batches
        except Exception as e:
            # Log the error and provide a helpful message
            print(f"Error creating embeddings: {e}")
            error_text = str(e).lower()
            if "cuda" in error_text or "gpu" in error_text:
                raise RuntimeError(f"GPU/CUDA error encountered. The model is configured to use CPU only. Error: {e}")
            elif "memory" in error_text:
                raise RuntimeError(f"Out of memory while creating embeddings. Try uploading smaller files or fewer files at once: {e}")
            else:
                raise RuntimeError(f"Failed to create embeddings: {e}")

        return np.vstack(embeddings) if embeddings else np.array([])

    def build_index(self, chunks: List[Dict[str, Any]], show_progress: bool = True):
        """Build FAISS index from chunks"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")
        if not chunks:
            raise ValueError("No chunks provided to index")

        # Extract texts and build embeddings
        texts = [chunk['text'] for chunk in chunks]
        if show_progress:
            print(f"Creating embeddings for {len(texts)} chunks...")
        embeddings = self.create_embeddings(texts)

        # Build FAISS index
        if show_progress:
            print("Building FAISS index...")

        # Use IndexFlatIP for inner product (cosine similarity with normalized vectors)
        self.index = faiss.IndexFlatIP(self.dimension)

        # Normalize embeddings in place for cosine similarity
        faiss.normalize_L2(embeddings)

        # Add to index
        self.index.add(embeddings)

        # Store chunks and maintain mapping
        self.chunks = {}
        self.chunk_ids = []
        for chunk in chunks:
            chunk_id = chunk['chunk_id']
            self.chunks[chunk_id] = chunk
            self.chunk_ids.append(chunk_id)

        if show_progress:
            print(f"Index built with {len(chunks)} chunks")

    def search(self, query: str, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search for similar chunks"""
        if not self.index or not self.chunks:
            return []

        # Create query embedding and normalize for cosine similarity
        query_embedding = self.create_embeddings([query])
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert to results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            results.append(SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            ))
        return results

    def serialize(self) -> Dict[str, Any]:
        """Serialize the vector store for deployment"""
        if not self.index:
            raise ValueError("No index to serialize")

        # faiss.serialize_index returns a uint8 numpy array; encode it as base64 text
        index_bytes = faiss.serialize_index(self.index)
        index_base64 = base64.b64encode(index_bytes).decode('utf-8')

        return {
            'index_base64': index_base64,
            'chunks': self.chunks,
            'chunk_ids': self.chunk_ids,
            'dimension': self.dimension,
            'model_name': self.embedding_model_name
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> 'VectorStore':
        """Deserialize a vector store from deployment data"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        store = cls(embedding_model=data['model_name'])

        # Deserialize the FAISS index (faiss expects a uint8 numpy array, not raw bytes)
        index_bytes = np.frombuffer(base64.b64decode(data['index_base64']), dtype=np.uint8)
        store.index = faiss.deserialize_index(index_bytes)

        # Restore chunks and mappings
        store.chunks = data['chunks']
        store.chunk_ids = data['chunk_ids']
        store.dimension = data['dimension']
        return store

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        return {
            'total_chunks': len(self.chunks),
            'index_size': self.index.ntotal if self.index else 0,
            'dimension': self.dimension,
            'model': self.embedding_model_name
        }


class LightweightVectorStore:
    """Lightweight version for deployed spaces without the embedding model"""

    def __init__(self, serialized_data: Dict[str, Any]):
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        # Deserialize the FAISS index (faiss expects a uint8 numpy array, not raw bytes)
        index_bytes = np.frombuffer(base64.b64decode(serialized_data['index_base64']), dtype=np.uint8)
        self.index = faiss.deserialize_index(index_bytes)

        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']
        self.dimension = serialized_data['dimension']

        # Query embeddings must be supplied externally (pre-computed or from a
        # lightweight embedding service); cache any that were shipped with the data
        self.query_embeddings_cache = serialized_data.get('query_embeddings_cache', {})

    def search_with_embedding(self, query_embedding: np.ndarray, top_k: int = 5,
                              score_threshold: float = 0.3) -> List[SearchResult]:
        """Search using a pre-computed query embedding"""
        if not self.index or not self.chunks:
            return []

        # FAISS expects a contiguous float32 array of shape (num_queries, dimension)
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert to results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            results.append(SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            ))
        return results


# Utility functions
def estimate_index_size(num_chunks: int, dimension: int = 384) -> float:
    """Estimate the size of the index in MB"""
    # Rough estimation: 4 bytes per float32 * dimension * num_chunks
    bytes_size = 4 * dimension * num_chunks
    # Add ~20% overhead for index structure and metadata
    overhead = 1.2
    return (bytes_size * overhead) / (1024 * 1024)
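

# --- Example usage ---
# A minimal sketch of the intended flow, assuming sentence-transformers and
# faiss-cpu are installed. The chunk dicts and query strings below are
# hypothetical placeholders, not data from a real document.
if __name__ == "__main__":
    sample_chunks = [
        {"chunk_id": "doc1-0", "text": "FAISS performs efficient similarity search.",
         "metadata": {"source": "doc1"}},
        {"chunk_id": "doc1-1", "text": "Sentence-transformers produce dense text embeddings.",
         "metadata": {"source": "doc1"}},
    ]

    print(f"Estimated index size: {estimate_index_size(len(sample_chunks)):.4f} MB")

    # Build an index and run a query with the full store
    store = VectorStore()
    store.build_index(sample_chunks)
    for result in store.search("How do I search embeddings quickly?", top_k=2, score_threshold=0.0):
        print(f"{result.chunk_id}: {result.score:.3f} -> {result.text}")

    # Round-trip through the deployment format
    payload = store.serialize()
    restored = VectorStore.deserialize(payload)
    print("Restored index stats:", restored.get_stats())

    # Deployment-side search with a pre-computed query embedding
    lightweight = LightweightVectorStore(payload)
    query_vec = store.create_embeddings(["similarity search"])[0]
    print("Lightweight results:", len(lightweight.search_with_embedding(query_vec, score_threshold=0.0)))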