# backend/vector_store.py
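"""Qdrant-backed vector store for the CA-Foundation app.

Embeds text with a locally bundled all-MiniLM-L6-v2 sentence-transformers
model and stores/searches the vectors in a Qdrant collection.
"""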
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import Dict, List, Optional
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)

class VectorStore:
    def __init__(self):
        self.collection_name = "ca-documents"

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to the Qdrant cluster with the API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        logger.info("Connected to Qdrant")

        # Initialize the embedding model from the bundled local copy
        self.embedding_model = self._initialize_embedding_model()

        # Create the collection with proper indices
        self._create_collection_if_not_exists()

    def _initialize_embedding_model(self):
        """Load the sentence-transformer model from the bundled local directory."""
        try:
            logger.info("Loading sentence transformer model from local path...")
            # Resolve the local path to the model directory (../model/all-MiniLM-L6-v2)
            current_dir = os.path.dirname(os.path.abspath(__file__))
            local_model_path = os.path.join(current_dir, "..", "model", "all-MiniLM-L6-v2")
            model = SentenceTransformer(local_model_path)
            logger.info("Successfully loaded local sentence transformer model")
            return model
        except Exception as e:
            logger.error(f"Failed to load local model: {e}")
            raise RuntimeError(
                "Failed to initialize embedding model from local path. "
                "Pre-download it when online, e.g.: "
                "SentenceTransformer('all-MiniLM-L6-v2').save('model/all-MiniLM-L6-v2')"
            ) from e

    def _create_collection_if_not_exists(self) -> bool:
        """
        Create the collection with proper payload indices if it doesn't exist.

        Returns:
            bool: True if the collection exists or was created successfully.
        """
        try:
            # Check if the collection already exists
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]
            logger.info(f"Existing collections: {collection_names}")
            if self.collection_name in collection_names:
                logger.info(f"Collection '{self.collection_name}' already exists")
                return True

            logger.info(f"Creating new collection: {self.collection_name}")
            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # Create the collection with vector configuration.
            # m=0 disables the global HNSW graph and payload_m=16 builds
            # per-payload-value graphs instead (Qdrant's multitenancy pattern),
            # so unfiltered searches fall back to a full scan.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices: a keyword index for exact matches on
            # document_id and a full-text index on the stored content
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            logger.info(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}", exc_info=True)
            return False

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Add a single document to the collection."""
        try:
            # Generate the embedding
            embedding = self.embedding_model.encode([text])[0]

            # Generate a document ID (also used as the Qdrant point ID)
            document_id = str(uuid.uuid4())

            # Create the payload with indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create the point
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )

            # Store it in Qdrant
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )
            return True
        except Exception as e:
            logger.error(f"Error adding document: {e}")
            return False

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Search for the documents most similar to the query."""
        try:
            # Generate the query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Flatten each hit into a dict, preserving extra metadata fields
            return [
                {
                    "text": hit.payload["content"],  # Use the content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            logger.error(f"Error searching: {e}")
            return []
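
    # NOTE: hypothetical helper, not part of the original class. A minimal
    # sketch of a filtered search using the "document_id" KEYWORD index created
    # in _create_collection_if_not_exists; the payload_m/m HNSW settings above
    # are tuned for exactly this kind of payload-filtered lookup.
    def search_in_document(self, query: str, document_id: str, limit: int = 5) -> List[Dict]:
        """Sketch: search for similar content within a single document."""
        query_embedding = self.embedding_model.encode([query])[0]
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding.tolist(),
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document_id),
                    )
                ]
            ),
            limit=limit,
        )
        return [{"text": hit.payload["content"], "score": hit.score} for hit in results]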

    def get_collection_info(self) -> Dict:
        """Get information about the collection."""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            vectors_config = collection_info.config.params.vectors
            return {
                # get_collection() does not return the name, so reuse our own
                "name": self.collection_name,
                "vector_size": vectors_config.size,
                "distance": vectors_config.distance,
                "points_count": collection_info.points_count,
                "on_disk": vectors_config.on_disk,
            }
        except Exception as e:
            logger.error(f"Error getting collection info: {e}")
            return {}
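

# A minimal usage sketch (hypothetical; assumes QDRANT_URL and QDRANT_API_KEY
# are set in the environment and the local model directory exists).
if __name__ == "__main__":
    store = VectorStore()
    store.add_document(
        "Depreciation is the systematic allocation of an asset's cost over its useful life.",
        metadata={"source": "notes.pdf", "page": 12},
    )
    for hit in store.search_similar("What is depreciation?", limit=3):
        print(hit["score"], hit["text"][:80])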