Spaces:
Sleeping
Sleeping
File size: 10,098 Bytes
deb090d 0972444 f2611d0 deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d aff287e deb090d 0972444 deb090d 0972444 deb090d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Any
from dotenv import load_dotenv
# Load environment variables from a .env file (supplies QDRANT_URL / QDRANT_API_KEY).
load_dotenv()
# Module-level logger; handlers and level are expected to be configured by the application.
logger = logging.getLogger(__name__)
class VectorStore:
    """Qdrant-backed vector store for document embeddings.

    Embeds text with a locally bundled all-MiniLM-L6-v2 sentence-transformer
    (384-dimensional vectors, cosine distance) and persists/searches the
    vectors in a remote Qdrant cluster configured through the QDRANT_URL and
    QDRANT_API_KEY environment variables.
    """

    def __init__(self):
        """Connect to Qdrant, load the embedding model, and ensure the collection exists.

        Raises:
            ValueError: if QDRANT_URL or QDRANT_API_KEY is not set.
            RuntimeError: if the local embedding model cannot be loaded.
        """
        self.collection_name = "ca-documents"

        # Get Qdrant configuration from environment variables.
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to the Qdrant cluster with the API key.
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize the embedding model from the bundled local directory.
        self.embedding_model = self._initialize_embedding_model()
        # Create the collection (with payload indices) if it does not exist yet.
        self._create_collection_if_not_exists()

    def _initialize_embedding_model(self):
        """Load the sentence-transformer model from a local directory.

        The model is expected at ``../model/all-MiniLM-L6-v2`` relative to this
        file, so no network access is required at startup.

        Returns:
            SentenceTransformer: the loaded embedding model.

        Raises:
            RuntimeError: if the local model directory cannot be loaded; the
                original loader exception is chained as the cause.
        """
        try:
            print("Loading sentence transformer model from local path...")
            # Resolve the model directory relative to this source file.
            current_dir = os.path.dirname(os.path.abspath(__file__))
            local_model_path = os.path.join(current_dir, "..", "model", "all-MiniLM-L6-v2")
            model = SentenceTransformer(local_model_path)
            print("Successfully loaded local sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load local model: {e}")
            # Chain the cause so the underlying loader error stays visible in tracebacks.
            raise RuntimeError("Failed to initialize embedding model from local path") from e

    def _create_collection_if_not_exists(self) -> bool:
        """Create the collection with payload indices if it does not already exist.

        Returns:
            bool: True if the collection exists or was created successfully,
            False if creation failed (the error is logged, not raised).
        """
        try:
            # Check whether the collection already exists.
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]
            print("list of collections : ", collection_names)
            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")
            # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
            vector_size = 384

            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                # m=0 disables the global HNSW graph; payload_m=16 builds
                # per-payload graphs instead (Qdrant's multitenancy setup,
                # relying on the payload indices created below).
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Payload indices: exact-match lookups on document_id, full-text on content.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False

    def add_document(self, text: str, metadata: Dict = None) -> bool:
        """Embed *text* and upsert it into the collection.

        Args:
            text: the document content to embed and store.
            metadata: optional extra payload fields merged into the point's
                payload (must not need to override ``document_id``/``content``).

        Returns:
            bool: True on success, False on any error (logged to stdout).
        """
        try:
            # Generate the embedding for the document text.
            embedding = self.embedding_model.encode([text])[0]
            # A fresh UUID serves as both the point ID and the document_id payload field.
            document_id = str(uuid.uuid4())

            # Payload fields backed by the collection's indices.
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,  # TEXT index - stores the actual text content
            }
            # Merge caller-supplied metadata fields, if any.
            if metadata:
                payload.update(metadata)

            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point],
            )
            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return the documents most similar to *query* by cosine similarity.

        Args:
            query: free-text query to embed and search with.
            limit: maximum number of hits to return.

        Returns:
            List[Dict]: one dict per hit with ``text``, ``document_id``,
            ``score``, plus any extra metadata fields stored on the point.
            Empty list on error.
        """
        try:
            # Embed the query with the same model used for documents.
            query_embedding = self.embedding_model.encode([query])[0]
            # NOTE(review): `client.search` is deprecated in newer qdrant-client
            # releases in favor of `query_points`; kept for compatibility.
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit,
            )
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]},
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []

    def get_collection_info(self) -> Dict:
        """Return summary information about the collection.

        Returns:
            Dict: name, vector size, distance metric, point count, and the
            vectors' on-disk flag; empty dict on error.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                # BUG FIX: CollectionConfig has no `.name` attribute, so the
                # original `collection_info.config.name` always raised
                # AttributeError and this method returned {}. Use the known
                # collection name instead.
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                "indexed_only": collection_info.config.params.vectors.on_disk,
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}