import os
from pinecone import Pinecone
from dotenv import load_dotenv
import logging
from typing import Optional, List, Dict, Any, Union, Tuple
import time
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field

# Configure logging
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Pinecone API key and index name
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Pinecone retrieval configuration
DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
ALLOWED_METRICS = os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")

# Export constants for importing elsewhere
__all__ = [
    'get_pinecone_index', 
    'check_db_connection', 
    'search_vectors', 
    'upsert_vectors', 
    'delete_vectors', 
    'fetch_metadata',
    'get_chain',
    'DEFAULT_TOP_K',
    'DEFAULT_LIMIT_K',
    'DEFAULT_SIMILARITY_METRIC',
    'DEFAULT_SIMILARITY_THRESHOLD',
    'ALLOWED_METRICS',
    'ThresholdRetriever'
]

# Configure Google API
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)

# Initialize global variables to store instances of Pinecone and index
pc = None
index = None
_retriever_instance = None

# Check environment variables
if not PINECONE_API_KEY:
    logger.error("PINECONE_API_KEY is not set in environment variables")

if not PINECONE_INDEX_NAME:
    logger.error("PINECONE_INDEX_NAME is not set in environment variables")

# Initialize Pinecone
def init_pinecone():
    """Initialize pinecone connection using new API"""
    global pc, index
    
    try:
        # Only initialize if not already initialized
        if pc is None:
            logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")
            
            # Check if API key and index name are set
            if not PINECONE_API_KEY:
                logger.error("PINECONE_API_KEY is not set in environment variables")
                return None
                
            if not PINECONE_INDEX_NAME:
                logger.error("PINECONE_INDEX_NAME is not set in environment variables")
                return None
            
            # Initialize Pinecone client using the new API
            pc = Pinecone(api_key=PINECONE_API_KEY)
            
            try:
                # Check if index exists
                index_list = pc.list_indexes()
                
                if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names():
                    logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
                    return None
                
                # Get existing index
                index = pc.Index(PINECONE_INDEX_NAME)
                logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")
            except Exception as connection_error:
                logger.error(f"Error connecting to Pinecone index: {connection_error}")
                return None
            
        return index
    except ImportError as e:
        logger.error(f"Required package for Pinecone is missing: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error initializing Pinecone: {e}")
        return None

# Get Pinecone index singleton
def get_pinecone_index():
    """Get Pinecone index"""
    global index
    if index is None:
        index = init_pinecone()
    return index

# Check Pinecone connection
def check_db_connection():
    """Check Pinecone connection"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            return False
            
        # Check index information to confirm connection is working
        stats = pinecone_index.describe_index_stats()
        
        # Prefer the top-level total, but fall back to summing per-namespace
        # counts when the stats expose non-empty namespaces
        total_vectors = stats.get('total_vector_count', 0)
        if hasattr(stats, 'namespaces') and stats.namespaces:
            total_vectors = sum(ns.get('vector_count', 0) for ns in stats.namespaces.values())
            
        logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}")
        return True
    except Exception as e:
        logger.error(f"Error in Pinecone connection: {e}")
        return False

# Convert similarity score based on the metric
def convert_score(score: float, metric: str) -> float:
    """
    Convert similarity score to a 0-1 scale based on the metric used.
    For metrics like euclidean distance where lower is better, we invert the score.
    
    Args:
        score: The raw similarity score
        metric: The similarity metric used
        
    Returns:
        A normalized score between 0-1 where higher means more similar
    """
    if metric.lower() in ["euclidean", "l2"]:
        # For distance metrics (lower is better), invert and normalize
        # Assuming max reasonable distance is 2.0 for normalized vectors
        return max(0, 1 - (score / 2.0))
    else:
        # For cosine and dot product (higher is better), return as is
        return score
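
# Illustrative values (a sketch, assuming unit-normalized vectors as the
# comment above does), showing how raw scores land on the shared 0-1 scale:
#
#     convert_score(0.92, "cosine")     # -> 0.92 (already higher-is-better)
#     convert_score(0.50, "euclidean")  # -> 0.75 (distance inverted: 1 - 0.5/2)
#     convert_score(2.00, "euclidean")  # -> 0.0  (maximal distance between unit vectors)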

# Filter results based on similarity threshold
def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
    """
    Filter query results based on similarity threshold.
    
    Args:
        results: The query results from Pinecone
        threshold: The similarity threshold (0-1)
        metric: The similarity metric used
        
    Returns:
        Filtered list of matches
    """
    filtered_matches = []
    
    if not hasattr(results, 'matches'):
        return filtered_matches
        
    for match in results.matches:
        # Get the score
        score = getattr(match, 'score', 0)
        
        # Convert score based on metric
        normalized_score = convert_score(score, metric)
        
        # Filter based on threshold
        if normalized_score >= threshold:
            # Add normalized score as an additional attribute
            match.normalized_score = normalized_score
            filtered_matches.append(match)
    
    return filtered_matches
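
# A minimal sketch of the filtering behaviour, using SimpleNamespace as a
# stand-in for Pinecone's response types (purely illustrative):
#
#     from types import SimpleNamespace
#     results = SimpleNamespace(matches=[
#         SimpleNamespace(id="a", score=0.90),
#         SimpleNamespace(id="b", score=0.60),
#     ])
#     kept = filter_by_threshold(results, threshold=0.75, metric="cosine")
#     # kept contains only match "a", with kept[0].normalized_score == 0.90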

# Search vectors in Pinecone with advanced options
async def search_vectors(
    query_vector, 
    top_k: int = DEFAULT_TOP_K,
    limit_k: int = DEFAULT_LIMIT_K,
    similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    namespace: str = "Default", 
    filter: Optional[Dict] = None
) -> Optional[Dict]:
    """
    Search for most similar vectors in Pinecone with advanced filtering options.
    
    Args:
        query_vector: The query vector
        top_k: Number of results to return (after threshold filtering)
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Metric used for score normalization and thresholding (cosine, dotproduct, euclidean); should match the metric the index was created with
        similarity_threshold: Threshold for similarity (0-1)
        namespace: Namespace to search in
        filter: Filter query
        
    Returns:
        Search results with matches filtered by threshold
    """
    try:
        # Validate parameters
        if similarity_metric not in ALLOWED_METRICS:
            logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
            similarity_metric = DEFAULT_SIMILARITY_METRIC
            
        if limit_k < top_k:
            logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
            limit_k = top_k
        
        # Perform search directly without cache
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for search")
            return None
            
        # Query Pinecone with limit_k candidates so threshold filtering still
        # leaves enough matches. Note: the similarity metric is fixed when the
        # index is created and cannot be passed to query(); similarity_metric
        # here only controls how scores are normalized and thresholded below.
        results = pinecone_index.query(
            vector=query_vector,
            top_k=limit_k,  # Retrieve more results than needed to allow for threshold filtering
            namespace=namespace,
            filter=filter,
            include_metadata=True,
            include_values=False  # No need to return vector values to save bandwidth
        )
        
        # Filter results by threshold
        filtered_matches = filter_by_threshold(results, similarity_threshold, similarity_metric)
        
        # Limit to top_k after filtering
        filtered_matches = filtered_matches[:top_k]
        
        # Create a new results object with filtered matches
        results.matches = filtered_matches
        
        # Log search result metrics
        match_count = len(filtered_matches)
        logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold}, namespace: {namespace})")
        
        return results
    except Exception as e:
        logger.error(f"Error searching vectors: {e}")
        return None
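
# Example call (a sketch; assumes an embedding whose dimensionality matches
# the index and a populated "Default" namespace):
#
#     embedding = GoogleGenerativeAIEmbeddings(
#         model="models/embedding-001").embed_query("user question")
#     results = await search_vectors(
#         query_vector=embedding,
#         top_k=5,             # keep at most 5 matches after thresholding
#         limit_k=20,          # pull 20 candidates before thresholding
#         similarity_threshold=0.8,
#     )
#     if results:
#         for match in results.matches:
#             print(match.id, match.normalized_score)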

# Upsert vectors to Pinecone
async def upsert_vectors(vectors, namespace="Default"):
    """Upsert vectors to Pinecone index"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for upsert")
            return None
            
        response = pinecone_index.upsert(
            vectors=vectors,
            namespace=namespace
        )
        
        # Log upsert metrics (the client may return an object or a dict-like response)
        upserted_count = getattr(response, 'upserted_count', None)
        if upserted_count is None and hasattr(response, 'get'):
            upserted_count = response.get('upserted_count', 0)
        logger.info(f"Upserted {upserted_count} vectors to Pinecone")
        
        return response
    except Exception as e:
        logger.error(f"Error upserting vectors: {e}")
        return None
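
# Example payload (a sketch; the ID, values, and metadata are placeholders,
# and each "values" list must match the index dimensionality):
#
#     await upsert_vectors([
#         {"id": "doc-1", "values": embedding, "metadata": {"text": "chunk text"}},
#     ])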

# Delete vectors from Pinecone
async def delete_vectors(ids, namespace="Default"):
    """Delete vectors from Pinecone index"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for delete")
            return False
            
        pinecone_index.delete(
            ids=ids,
            namespace=namespace
        )
        
        logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
        return True
    except Exception as e:
        logger.error(f"Error deleting vectors: {e}")
        return False

# Fetch vector metadata from Pinecone
async def fetch_metadata(ids, namespace="Default"):
    """Fetch metadata for specific vector IDs"""
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for fetch")
            return None
            
        response = pinecone_index.fetch(
            ids=ids,
            namespace=namespace
        )
        
        return response
    except Exception as e:
        logger.error(f"Error fetching vector metadata: {e}")
        return None

# Create a custom retriever class for Langchain integration
class ThresholdRetriever(BaseRetriever):
    """
    Custom retriever that supports threshold-based filtering and multiple similarity metrics.
    This integrates with the Langchain ecosystem while using our advanced retrieval logic.
    """
    
    vectorstore: Any = Field(description="Vector store to use for retrieval")
    embeddings: Any = Field(description="Embeddings model to use for retrieval")
    search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
    top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
    limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
    similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
    similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")
    namespace: str = "Default"

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True
    
    async def search_vectors_sync(
        self, query_vector,
        top_k: int = DEFAULT_TOP_K,
        limit_k: int = DEFAULT_LIMIT_K,
        similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        namespace: str = "Default",
        filter: Optional[Dict] = None
    ) -> Optional[Dict]:
        """
        Thin async wrapper around the module-level search_vectors.

        Despite the legacy "_sync" name, this is a coroutine and must be
        awaited; wrapping the call in create_task just to await it would be
        a no-op, so it simply delegates and logs any failure.
        """
        try:
            return await search_vectors(
                query_vector=query_vector,
                top_k=top_k,
                limit_k=limit_k,
                similarity_metric=similarity_metric,
                similarity_threshold=similarity_threshold,
                namespace=namespace,
                filter=filter
            )
        except Exception as e:
            logger.error(f"Error in search_vectors_sync: {e}")
            return None

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        """
        Get documents relevant to the query using threshold-based retrieval.
        
        Args:
            query: The query string
            run_manager: The callbacks manager
            
        Returns:
            List of relevant documents
        """
        # Generate embedding for query using the embeddings model
        try:
            # Use the embeddings model we stored in the class
            embedding = self.embeddings.embed_query(query)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            # Fallback to creating a new embedding model if needed
            embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embedding = embedding_model.embed_query(query)
        
        # Drive the async search safely from this synchronous method
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            # We're inside an existing event loop (e.g. a FastAPI handler), so we
            # must not block it directly; run the coroutine on a fresh loop in a
            # worker thread instead
            def run_async_in_thread():
                thread_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(thread_loop)
                try:
                    return thread_loop.run_until_complete(search_vectors(
                        query_vector=embedding,
                        top_k=self.top_k,
                        limit_k=self.limit_k,
                        similarity_metric=self.similarity_metric,
                        similarity_threshold=self.similarity_threshold,
                        namespace=self.namespace,
                        filter=self.search_kwargs.get("filter")  # optional metadata filter
                    ))
                finally:
                    thread_loop.close()

            # Run the async function in a worker thread and block on the result
            with ThreadPoolExecutor() as executor:
                search_result = executor.submit(run_async_in_thread).result()
        else:
            # No event loop is running in this thread, so asyncio.run can
            # drive the coroutine to completion directly
            search_result = asyncio.run(search_vectors(
                query_vector=embedding,
                top_k=self.top_k,
                limit_k=self.limit_k,
                similarity_metric=self.similarity_metric,
                similarity_threshold=self.similarity_threshold,
                namespace=self.namespace,
                filter=self.search_kwargs.get("filter")  # optional metadata filter
            ))
        
        # Convert matches to LangChain Documents
        documents = []
        if search_result and hasattr(search_result, 'matches'):
            for match in search_result.matches:
                # Copy metadata so we never mutate the Pinecone match object
                metadata = dict(match.metadata) if getattr(match, 'metadata', None) else {}

                # Record both the raw and normalized scores on the document
                score = getattr(match, 'score', 0)
                metadata['score'] = score
                metadata['normalized_score'] = getattr(match, 'normalized_score', score)

                # The "text" key holds the page content; drop it from metadata
                text = metadata.pop('text', '')

                documents.append(Document(page_content=text, metadata=metadata))

        return documents
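
# Usage sketch for ThresholdRetriever (illustrative; instances are normally
# obtained via get_chain() below). BaseRetriever implements the Runnable
# interface, so .invoke(query) drives _get_relevant_documents:
#
#     retriever = get_chain()
#     docs = retriever.invoke("example user question")
#     for doc in docs:
#         print(doc.metadata["normalized_score"], doc.page_content[:60])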

# Get the retrieval chain with Pinecone vector store
def get_chain(
    index_name=PINECONE_INDEX_NAME, 
    namespace="Default", 
    top_k=DEFAULT_TOP_K, 
    limit_k=DEFAULT_LIMIT_K,
    similarity_metric=DEFAULT_SIMILARITY_METRIC, 
    similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
):
    """
    Get the retrieval chain with Pinecone vector store using threshold-based retrieval.
    
    Args:
        index_name: Pinecone index name
        namespace: Pinecone namespace
        top_k: Number of results to return after filtering
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)
        
    Returns:
        ThresholdRetriever instance
    """
    global _retriever_instance
    try:
        # Return the cached retriever if one exists; note the cache ignores the
        # arguments, so the parameters of the first call win for the process lifetime
        if _retriever_instance is not None:
            return _retriever_instance
            
        start_time = time.time()
        logger.info("Initializing new retriever chain with threshold-based filtering")
        
        # Initialize embeddings model
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        
        # Get index
        pinecone_index = get_pinecone_index()
        if not pinecone_index:
            logger.error("Failed to get Pinecone index for retriever chain")
            return None
            
        # Get statistics for logging
        try:
            stats = pinecone_index.describe_index_stats()
            total_vectors = stats.get('total_vector_count', 0)
            logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
        except Exception as e:
            logger.error(f"Error getting index stats: {e}")
        
        # Use Pinecone from langchain_community.vectorstores
        from langchain_community.vectorstores import Pinecone as LangchainPinecone
        
        logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
        vectorstore = LangchainPinecone.from_existing_index(
            embedding=embeddings,
            index_name=index_name,
            namespace=namespace,
            text_key="text"
        )
        
        # Create threshold-based retriever
        logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " +
                    f"metric={similarity_metric}, threshold={similarity_threshold}")
        
        # Create ThresholdRetriever with both vectorstore and embeddings
        _retriever_instance = ThresholdRetriever(
            vectorstore=vectorstore,
            embeddings=embeddings,  # Pass embeddings separately
            top_k=top_k,
            limit_k=limit_k,
            similarity_metric=similarity_metric,
            similarity_threshold=similarity_threshold
        )
        
        logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")
        
        return _retriever_instance
    except Exception as e:
        logger.error(f"Error creating retrieval chain: {e}")
        return None
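
# A minimal end-to-end sketch (assumes PINECONE_API_KEY, PINECONE_INDEX_NAME and
# GOOGLE_API_KEY are set, and that indexed vectors carry a "text" metadata field;
# the query string is purely illustrative):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if check_db_connection():
        retriever = get_chain(top_k=4, similarity_threshold=0.7)
        if retriever is not None:
            for doc in retriever.invoke("example user question"):
                print(doc.metadata.get("normalized_score"), doc.page_content[:80])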