# Pix-Agent / app / database / pinecone.py
# ManTea's picture
# QA to PROD
# 0e5b8f8
import os
from pinecone import Pinecone
from dotenv import load_dotenv
import logging
from typing import Optional, List, Dict, Any, Union, Tuple
import time
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
# Module-level logger, named after this module so records can be filtered per module
logger = logging.getLogger(__name__)
# Load variables from a .env file (if any) before the os.getenv() reads below
load_dotenv()
# Credentials and target index, read from the environment.
# Each is None when the variable is not set.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Pinecone retrieval configuration (each value can be overridden via the environment)
# DEFAULT_LIMIT_K: how many candidates to request from Pinecone before filtering
DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
# DEFAULT_TOP_K: how many matches to keep after threshold filtering
DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
# Comma-separated list of accepted metric names. Entries are stripped and empty
# ones dropped so that a value such as "cosine, dotproduct" (with spaces or a
# trailing comma) still yields names that can match a requested metric.
ALLOWED_METRICS = [
    metric.strip()
    for metric in os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")
    if metric.strip()
]
# Names exported by `from <this module> import *`
__all__ = [
    # Connection helpers
    "get_pinecone_index",
    "check_db_connection",
    # Vector operations
    "search_vectors",
    "upsert_vectors",
    "delete_vectors",
    "fetch_metadata",
    # LangChain integration
    "get_chain",
    # Retrieval defaults
    "DEFAULT_TOP_K",
    "DEFAULT_LIMIT_K",
    "DEFAULT_SIMILARITY_METRIC",
    "DEFAULT_SIMILARITY_THRESHOLD",
    "ALLOWED_METRICS",
    "ThresholdRetriever",
]
# Module-level singletons: the Pinecone client, the index handle and the
# cached retriever. All start empty and are filled in by the functions below.
pc = None
index = None
_retriever_instance = None

# Configure the Google Generative AI SDK only when a key was provided
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)

# Report missing Pinecone settings once, at import time
for _setting_value, _missing_message in (
    (PINECONE_API_KEY, "PINECONE_API_KEY is not set in environment variables"),
    (PINECONE_INDEX_NAME, "PINECONE_INDEX_NAME is not set in environment variables"),
):
    if not _setting_value:
        logger.error(_missing_message)
# Do not leave the loop variables behind in the module namespace
del _setting_value, _missing_message
# Initialize Pinecone
def init_pinecone():
    """
    Create the Pinecone client and bind the configured index.

    The module globals ``pc`` and ``index`` are only published once the index
    handle has been obtained, so a failed attempt (missing index, connection
    error) leaves no half-initialized state behind and the next call tries
    again instead of returning None forever.

    Returns:
        The index handle, or None when the configuration is incomplete, the
        index does not exist, or any error occurs (the error is logged).
    """
    global pc, index
    try:
        # Already connected: reuse the cached handle
        if index is not None:
            return index
        logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")
        # Check if API key and index name are set
        if not PINECONE_API_KEY:
            logger.error("PINECONE_API_KEY is not set in environment variables")
            return None
        if not PINECONE_INDEX_NAME:
            logger.error("PINECONE_INDEX_NAME is not set in environment variables")
            return None
        # Reuse a client left over from an earlier attempt, otherwise create one
        client = pc if pc is not None else Pinecone(api_key=PINECONE_API_KEY)
        try:
            # Check if index exists
            index_list = client.list_indexes()
            if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names():
                logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
                return None
            # Get existing index
            connected_index = client.Index(PINECONE_INDEX_NAME)
        except Exception as connection_error:
            logger.error(f"Error connecting to Pinecone index: {connection_error}")
            return None
        # Publish the client and the index together, only after full success
        pc = client
        index = connected_index
        logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")
        return index
    except ImportError as e:
        logger.error(f"Required package for Pinecone is missing: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error initializing Pinecone: {e}")
        return None
# Get Pinecone index singleton
def get_pinecone_index():
    """
    Return the shared Pinecone index handle.

    The handle is taken from the module global ``index``; when that is still
    None, ``init_pinecone()`` is called and its result (possibly None) is
    stored and returned.
    """
    global index
    if index is not None:
        return index
    index = init_pinecone()
    return index
# Check Pinecone connection
def check_db_connection():
    """
    Report whether the Pinecone index can be reached.

    Fetches the index statistics and logs the vector count.

    Returns:
        True when the statistics call succeeds; False when no index handle is
        available or any exception is raised (the exception is logged).
    """
    try:
        handle = get_pinecone_index()
        if handle is None:
            return False
        # A successful stats call is what confirms the connection works
        stats = handle.describe_index_stats()
        total_vectors = stats.get('total_vector_count', 0)
        if hasattr(stats, 'namespaces'):
            # Per-namespace counts take precedence over the top-level total
            total_vectors = 0
            for namespace_stats in stats.namespaces.values():
                total_vectors += namespace_stats.get('vector_count', 0)
        logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}")
    except Exception as e:
        logger.error(f"Error in Pinecone connection: {e}")
        return False
    else:
        return True
# Convert similarity score based on the metric
def convert_score(score: float, metric: str) -> float:
    """
    Map a raw score onto a "higher means more similar" scale.

    Args:
        score: The raw score reported for a match
        metric: Name of the similarity metric (compared case-insensitively)

    Returns:
        For "euclidean" / "l2": ``1 - score / 2`` clamped below at 0, which
        treats 2.0 as the largest meaningful distance (normalized vectors).
        For every other metric the score is returned unchanged.
    """
    is_distance_metric = metric.lower() in ("euclidean", "l2")
    if not is_distance_metric:
        # Cosine and dot product already grow with similarity
        return score
    # Distances shrink with similarity, so invert them
    inverted = 1 - (score / 2.0)
    return max(0, inverted)
# Filter results based on similarity threshold
def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
    """
    Keep only the matches whose normalized score reaches the threshold.

    Each kept match gets a ``normalized_score`` attribute set on it.

    Args:
        results: Query response; anything without a ``matches`` attribute
            yields an empty list
        threshold: Minimum normalized score for a match to be kept
        metric: Similarity metric name, forwarded to ``convert_score``

    Returns:
        The kept matches, in their original order
    """
    kept = []
    if hasattr(results, 'matches'):
        for candidate in results.matches:
            # A match without a score is treated as scoring 0
            raw_score = getattr(candidate, 'score', 0)
            normalized = convert_score(raw_score, metric)
            if not normalized >= threshold:
                continue
            # Remember the normalized value on the match for later consumers
            candidate.normalized_score = normalized
            kept.append(candidate)
    return kept
# Search vectors in Pinecone with advanced options
async def search_vectors(
    query_vector,
    top_k: int = DEFAULT_TOP_K,
    limit_k: int = DEFAULT_LIMIT_K,
    similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    namespace: str = "Default",
    filter: Optional[Dict] = None
) -> Dict:
    """
    Query Pinecone and keep the best matches that pass a similarity threshold.

    Up to ``limit_k`` candidates are requested, those below the threshold are
    dropped, and at most ``top_k`` of the remainder are kept.

    Args:
        query_vector: The query embedding
        top_k: Maximum number of matches to keep after filtering
        limit_k: Number of candidates to request (raised to ``top_k`` if smaller)
        similarity_metric: Metric name; falls back to the default when it is
            not in ALLOWED_METRICS
        similarity_threshold: Minimum normalized score for a match
        namespace: Namespace to search in
        filter: Metadata filter passed through to the query

    Returns:
        The query response with ``matches`` replaced by the filtered list, or
        None when no index is available or an error occurs (logged).
    """
    try:
        # Fall back to safe values instead of rejecting the request
        if similarity_metric not in ALLOWED_METRICS:
            logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
            similarity_metric = DEFAULT_SIMILARITY_METRIC
        if limit_k < top_k:
            logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
            limit_k = top_k

        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for search")
            return None

        query_kwargs = dict(
            vector=query_vector,
            top_k=limit_k,  # over-fetch so that threshold filtering has candidates to drop
            namespace=namespace,
            filter=filter,
            include_metadata=True,
            include_values=False,  # vector values are not needed by callers
            metric=similarity_metric,
        )
        results = pinecone_index.query(**query_kwargs)

        # Threshold first, then cap at top_k; the response object is reused
        above_threshold = filter_by_threshold(results, similarity_threshold, similarity_metric)
        results.matches = above_threshold[:top_k]

        match_count = len(results.matches)
        logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold}, namespace: {namespace})")
        return results
    except Exception as e:
        logger.error(f"Error searching vectors: {e}")
        return None
# Upsert vectors to Pinecone
async def upsert_vectors(vectors, namespace="Default"):
    """
    Insert or update vectors in the given namespace.

    Returns:
        The upsert response, or None when no index is available or the call
        raises (the error is logged).
    """
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is not None:
            response = pinecone_index.upsert(vectors=vectors, namespace=namespace)
            # Log how many vectors the service reports as written
            upserted_count = response.get('upserted_count', 0)
            logger.info(f"Upserted {upserted_count} vectors to Pinecone")
            return response
        logger.error("Failed to get Pinecone index for upsert")
        return None
    except Exception as e:
        logger.error(f"Error upserting vectors: {e}")
        return None
# Delete vectors from Pinecone
async def delete_vectors(ids, namespace="Default"):
    """
    Delete the vectors with the given IDs from a namespace.

    Returns:
        True when the delete call completes; False when no index is available
        or the call raises (the error is logged).
    """
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is not None:
            # The response carries nothing this function needs
            pinecone_index.delete(ids=ids, namespace=namespace)
            logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
            return True
        logger.error("Failed to get Pinecone index for delete")
        return False
    except Exception as e:
        logger.error(f"Error deleting vectors: {e}")
        return False
# Fetch vector metadata from Pinecone
async def fetch_metadata(ids, namespace="Default"):
    """
    Fetch the stored records for specific vector IDs.

    Returns:
        The fetch response, or None when no index is available or the call
        raises (the error is logged).
    """
    try:
        pinecone_index = get_pinecone_index()
        if pinecone_index is not None:
            return pinecone_index.fetch(ids=ids, namespace=namespace)
        logger.error("Failed to get Pinecone index for fetch")
        return None
    except Exception as e:
        logger.error(f"Error fetching vector metadata: {e}")
        return None
# Create a custom retriever class for Langchain integration
class ThresholdRetriever(BaseRetriever):
    """
    LangChain retriever backed by this module's ``search_vectors``.

    A query is embedded with ``embeddings``, searched in Pinecone with the
    configured metric / threshold / limits, and every surviving match is
    returned as a ``Document`` whose content is the match's ``text`` metadata.
    """
    vectorstore: Any = Field(description="Vector store to use for retrieval")
    embeddings: Any = Field(description="Embeddings model to use for retrieval")
    search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
    top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
    limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
    similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
    similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")
    namespace: str = "Default"

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True

    async def search_vectors_sync(
        self, query_vector,
        top_k: int = DEFAULT_TOP_K,
        limit_k: int = DEFAULT_LIMIT_K,
        similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        namespace: str = "Default",
        filter: Optional[Dict] = None
    ) -> Dict:
        """
        Await ``search_vectors`` with the given arguments.

        Despite the name this is a coroutine and must be awaited. Because it
        can only execute inside a running event loop, it simply awaits the
        search; no loop inspection is needed.

        Returns:
            Whatever ``search_vectors`` returns, or None if it raises (the
            error is logged).
        """
        try:
            return await search_vectors(
                query_vector=query_vector,
                top_k=top_k,
                limit_k=limit_k,
                similarity_metric=similarity_metric,
                similarity_threshold=similarity_threshold,
                namespace=namespace,
                filter=filter
            )
        except Exception as e:
            logger.error(f"Error in search_vectors_sync: {e}")
            return None

    def _run_search(self, embedding):
        """Run the async ``search_vectors`` to completion from synchronous code."""
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        def run_in_private_loop():
            # A dedicated loop that is always closed afterwards and never
            # installed as the thread's current loop, so no loop is leaked and
            # the caller's event-loop state is left untouched.
            private_loop = asyncio.new_event_loop()
            try:
                # NOTE(review): search_kwargs (e.g. a "filter" entry) is not
                # forwarded to the search, same as before this rewrite -
                # confirm whether that is intentional.
                return private_loop.run_until_complete(search_vectors(
                    query_vector=embedding,
                    top_k=self.top_k,
                    limit_k=self.limit_k,
                    similarity_metric=self.similarity_metric,
                    similarity_threshold=self.similarity_threshold,
                    namespace=self.namespace,
                ))
            finally:
                private_loop.close()

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: run the search right here
            return run_in_private_loop()
        # A loop is already running in this thread (like in FastAPI), where
        # run_until_complete() would raise; use a worker thread and wait for it.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(run_in_private_loop).result()

    def _match_to_document(self, match) -> Document:
        """Convert one Pinecone match into a LangChain Document."""
        # Work on a copy: the match may carry no metadata at all (None), and
        # the response object should not be modified as a side effect.
        metadata = dict(getattr(match, 'metadata', None) or {})
        score = getattr(match, 'score', 0)
        metadata['score'] = score
        # Fall back to the raw score when no normalized score was attached
        metadata['normalized_score'] = getattr(match, 'normalized_score', score)
        # The text becomes the page content, so it is removed from the metadata
        text = metadata.pop('text', '') or ''
        return Document(page_content=text, metadata=metadata)

    def _get_relevant_documents(
        self, query: str, *, run_manager: Callbacks = None
    ) -> List[Document]:
        """
        Get documents relevant to the query using threshold-based retrieval.

        Args:
            query: The query string
            run_manager: The callbacks manager (not used here)

        Returns:
            One Document per match that passed the threshold; an empty list
            when the search returned nothing usable.
        """
        try:
            # Use the embeddings model stored on the retriever
            embedding = self.embeddings.embed_query(query)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            # Fall back to a freshly created embedding model
            embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embedding = embedding_model.embed_query(query)

        search_result = self._run_search(embedding)
        if not search_result or not hasattr(search_result, 'matches'):
            return []
        return [self._match_to_document(match) for match in search_result.matches]
# Arguments the cached retriever was built with. Lets get_chain notice a
# request for a different configuration instead of handing back a stale one.
_retriever_params: Optional[Tuple] = None


# Get the retrieval chain with Pinecone vector store
def get_chain(
    index_name=PINECONE_INDEX_NAME,
    namespace="Default",
    top_k=DEFAULT_TOP_K,
    limit_k=DEFAULT_LIMIT_K,
    similarity_metric=DEFAULT_SIMILARITY_METRIC,
    similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
):
    """
    Get the retrieval chain with Pinecone vector store using threshold-based retrieval.

    The most recently built retriever is cached and reused for as long as the
    same arguments are requested; a call with different arguments builds (and
    caches) a new retriever.

    Args:
        index_name: Pinecone index name
        namespace: Pinecone namespace (used for the vector store and for the
            retriever's searches)
        top_k: Number of results to return after filtering
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)

    Returns:
        ThresholdRetriever instance, or None when the index is unavailable or
        construction fails (the error is logged).
    """
    global _retriever_instance, _retriever_params
    requested_params = (
        index_name, namespace, top_k, limit_k, similarity_metric, similarity_threshold
    )
    try:
        # Reuse the cached instance only when it was built with these arguments
        if _retriever_instance is not None and _retriever_params == requested_params:
            return _retriever_instance
        start_time = time.time()
        logger.info("Initializing new retriever chain with threshold-based filtering")
        # Initialize embeddings model
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        # Get index
        pinecone_index = get_pinecone_index()
        if pinecone_index is None:
            logger.error("Failed to get Pinecone index for retriever chain")
            return None
        # Statistics are for logging only; a failure here is not fatal
        try:
            stats = pinecone_index.describe_index_stats()
            total_vectors = stats.get('total_vector_count', 0)
            logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
        except Exception as e:
            logger.error(f"Error getting index stats: {e}")
        # Use Pinecone from langchain_community.vectorstores
        from langchain_community.vectorstores import Pinecone as LangchainPinecone
        logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
        vectorstore = LangchainPinecone.from_existing_index(
            embedding=embeddings,
            index_name=index_name,
            namespace=namespace,
            text_key="text"
        )
        # Create threshold-based retriever
        logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " +
                    f"metric={similarity_metric}, threshold={similarity_threshold}")
        retriever = ThresholdRetriever(
            vectorstore=vectorstore,
            embeddings=embeddings,  # Pass embeddings separately
            top_k=top_k,
            limit_k=limit_k,
            similarity_metric=similarity_metric,
            similarity_threshold=similarity_threshold,
            namespace=namespace  # searches must target the requested namespace
        )
        # Publish the instance and its key together, only after success
        _retriever_instance = retriever
        _retriever_params = requested_params
        logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")
        return _retriever_instance
    except Exception as e:
        logger.error(f"Error creating retrieval chain: {e}")
        return None