from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv

# Load environment variables from a local .env file, if present.
load_dotenv()

logger = logging.getLogger(__name__)


class VectorStore:
    """Thin wrapper around a Qdrant collection of embedded text documents.

    Embeds text with a locally bundled all-MiniLM-L6-v2 sentence-transformer
    and stores / searches the resulting 384-dim vectors in the
    "ca-documents" collection.
    """

    def __init__(self):
        """Connect to Qdrant, load the embedding model, and ensure the collection exists.

        Raises:
            ValueError: if QDRANT_URL or QDRANT_API_KEY is not set.
            RuntimeError: if the local embedding model cannot be loaded.
        """
        self.collection_name = "ca-documents"

        # Qdrant connection settings come from the environment (.env supported).
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        logger.info("Connected to Qdrant")

        self.embedding_model = self._initialize_embedding_model()
        self._create_collection_if_not_exists()

    def _initialize_embedding_model(self) -> SentenceTransformer:
        """Load the sentence-transformer model from the repo-local model directory.

        The model is expected at ``../model/all-MiniLM-L6-v2`` relative to this
        file, so no network access is needed.

        Returns:
            SentenceTransformer: the loaded embedding model.

        Raises:
            RuntimeError: if the model directory is missing or unreadable.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        local_model_path = os.path.join(current_dir, "..", "model", "all-MiniLM-L6-v2")
        try:
            logger.info("Loading sentence transformer model from %s", local_model_path)
            model = SentenceTransformer(local_model_path)
            logger.info("Successfully loaded local sentence transformer model")
            return model
        except Exception as e:
            logger.error("Failed to load local model: %s", e, exc_info=True)
            # Chain the cause so the underlying loader error is not lost.
            raise RuntimeError("Failed to initialize embedding model from local path") from e

    def _create_collection_if_not_exists(self) -> bool:
        """Create the collection with payload indices if it does not already exist.

        Returns:
            bool: True if the collection exists or was created successfully,
            False if creation failed (the error is logged, not raised).
        """
        try:
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]
            logger.info("Existing collections: %s", collection_names)

            if self.collection_name in collection_names:
                logger.info("Collection '%s' already exists", self.collection_name)
                return True

            logger.info("Creating new collection: %s", self.collection_name)

            # Derive the vector size from the loaded model (384 for
            # all-MiniLM-L6-v2); fall back to 384 if the model cannot report it.
            vector_size = self.embedding_model.get_sentence_embedding_dimension() or 384

            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                # m=0 disables the global HNSW graph; payload_m builds per-payload
                # graphs instead (Qdrant's recommended multitenancy setup).
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Index the payload fields used for filtering / full-text lookup.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            logger.info("Successfully created collection: %s", self.collection_name)
            return True

        except Exception as e:
            logger.error(
                "Failed to create collection %s: %s", self.collection_name, e, exc_info=True
            )
            return False

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Embed *text* and upsert it into the collection.

        Args:
            text: the document content to embed and store.
            metadata: optional extra payload fields merged into the point's
                payload. NOTE: keys named "document_id" or "content" would
                overwrite the indexed fields — callers should avoid them.

        Returns:
            bool: True on success, False on failure (errors are logged).
        """
        try:
            embedding = self.embedding_model.encode([text])[0]

            # A fresh UUID serves as both the point id and the indexed payload key.
            document_id = str(uuid.uuid4())

            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - the actual text content
            }
            if metadata:
                payload.update(metadata)

            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point],
            )
            return True

        except Exception as e:
            logger.error("Error adding document: %s", e, exc_info=True)
            return False

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return the documents most similar to *query* by cosine similarity.

        Args:
            query: free-text query to embed and search with.
            limit: maximum number of hits to return.

        Returns:
            List[Dict]: one dict per hit with "text", "document_id", "score",
            plus any additional payload fields; [] on failure.
        """
        try:
            query_embedding = self.embedding_model.encode([query])[0]

            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit,
            )

            return [
                {
                    # .get keeps a single malformed payload from dropping the
                    # whole result set.
                    "text": hit.payload.get("content"),
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Pass through any extra metadata fields.
                    **{
                        k: v
                        for k, v in hit.payload.items()
                        if k not in ("content", "document_id")
                    },
                }
                for hit in results
            ]

        except Exception as e:
            logger.error("Error searching: %s", e, exc_info=True)
            return []

    def get_collection_info(self) -> Dict:
        """Return basic stats about the collection, or {} on failure."""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                # BUGFIX: CollectionConfig has no .name attribute (the old
                # collection_info.config.name always raised AttributeError,
                # making this method return {}); report the known name instead.
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                # NOTE(review): this reflects the vectors' on_disk flag, not an
                # "indexed only" mode; key kept for backward compatibility.
                "indexed_only": collection_info.config.params.vectors.on_disk,
            }
        except Exception as e:
            logger.error("Error getting collection info: %s", e, exc_info=True)
            return {}