|
from typing import Optional, List, Dict, Any, Union |
|
import logging |
|
import time |
|
from pinecone import ServerlessSpec |
|
|
|
import asyncio |
|
import functools |
|
|
|
import concurrent.futures |
|
from pinecone.grpc import PineconeGRPC |
|
|
|
from open_webui.retrieval.vector.main import ( |
|
VectorDBBase, |
|
VectorItem, |
|
SearchResult, |
|
GetResult, |
|
) |
|
from open_webui.config import ( |
|
PINECONE_API_KEY, |
|
PINECONE_ENVIRONMENT, |
|
PINECONE_INDEX_NAME, |
|
PINECONE_DIMENSION, |
|
PINECONE_METRIC, |
|
PINECONE_CLOUD, |
|
) |
|
from open_webui.env import SRC_LOG_LEVELS |
|
|
|
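# Pinecone caps top_k per query (10,000 at the time of writing), so this is the effective "no limit" value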
NO_LIMIT = 10000 |
|
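# Number of vectors sent per upsert/delete batch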
BATCH_SIZE = 100 |
|
|
|
log = logging.getLogger(__name__) |
|
log.setLevel(SRC_LOG_LEVELS["RAG"]) |
|
|
|
|
|
class PineconeClient(VectorDBBase): |
|
def __init__(self): |
|
self.collection_prefix = "open-webui" |
|
|
|
|
|
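        # Fail fast if any required Pinecone settings are missing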
self._validate_config() |
|
|
|
|
|
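        # Store configuration values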
self.api_key = PINECONE_API_KEY |
|
self.environment = PINECONE_ENVIRONMENT |
|
self.index_name = PINECONE_INDEX_NAME |
|
self.dimension = PINECONE_DIMENSION |
|
self.metric = PINECONE_METRIC |
|
self.cloud = PINECONE_CLOUD |
|
|
|
|
|
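        # Initialize the gRPC client (generally better upsert throughput than REST)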
self.client = PineconeGRPC( |
|
api_key=self.api_key, environment=self.environment, cloud=self.cloud |
|
) |
|
|
|
|
|
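        # Shared thread pool for parallel batch upserts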
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) |
|
|
|
|
|
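        # Create the index if needed and open a handle to it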
self._initialize_index() |
|
|
|
def _validate_config(self) -> None: |
|
"""Validate that all required configuration variables are set.""" |
|
missing_vars = [] |
|
if not PINECONE_API_KEY: |
|
missing_vars.append("PINECONE_API_KEY") |
|
if not PINECONE_ENVIRONMENT: |
|
missing_vars.append("PINECONE_ENVIRONMENT") |
|
if not PINECONE_INDEX_NAME: |
|
missing_vars.append("PINECONE_INDEX_NAME") |
|
if not PINECONE_DIMENSION: |
|
missing_vars.append("PINECONE_DIMENSION") |
|
if not PINECONE_CLOUD: |
|
missing_vars.append("PINECONE_CLOUD") |
|
|
|
if missing_vars: |
|
raise ValueError( |
|
f"Required configuration missing: {', '.join(missing_vars)}" |
|
) |
|
|
|
def _initialize_index(self) -> None: |
|
"""Initialize the Pinecone index.""" |
|
try: |
|
|
|
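            # Create the serverless index only if it does not already exist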
if self.index_name not in self.client.list_indexes().names(): |
|
log.info(f"Creating Pinecone index '{self.index_name}'...") |
|
self.client.create_index( |
|
name=self.index_name, |
|
dimension=self.dimension, |
|
metric=self.metric, |
|
spec=ServerlessSpec(cloud=self.cloud, region=self.environment), |
|
) |
|
log.info(f"Successfully created Pinecone index '{self.index_name}'") |
|
else: |
|
log.info(f"Using existing Pinecone index '{self.index_name}'") |
|
|
|
|
|
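            # Open a handle to the index for all subsequent operations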
self.index = self.client.Index(self.index_name) |
|
|
|
except Exception as e: |
|
log.error(f"Failed to initialize Pinecone index: {e}") |
|
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}") from e
|
|
|
def _create_points( |
|
self, items: List[VectorItem], collection_name_with_prefix: str |
|
) -> List[Dict[str, Any]]: |
|
"""Convert VectorItem objects to Pinecone point format.""" |
|
points = [] |
|
for item in items: |
|
|
|
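            # Copy any existing metadata so the caller's item is not mutated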
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {} |
|
|
|
|
|
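            # Keep the document text in metadata so queries can return it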
if "text" in item: |
|
metadata["text"] = item["text"] |
|
|
|
|
|
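            # Tag each vector with its collection; all collections share one physical index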
metadata["collection_name"] = collection_name_with_prefix |
|
|
|
point = { |
|
"id": item["id"], |
|
"values": item["vector"], |
|
"metadata": metadata, |
|
} |
|
points.append(point) |
|
return points |
|
|
|
def _get_collection_name_with_prefix(self, collection_name: str) -> str: |
|
"""Get the collection name with prefix.""" |
|
return f"{self.collection_prefix}_{collection_name}" |
|
|
|
def _normalize_distance(self, score: float) -> float: |
|
"""Normalize distance score based on the metric used.""" |
|
if self.metric.lower() == "cosine": |
|
|
|
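            # Cosine similarity falls in [-1, 1]; map it to [0, 1]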
return (score + 1.0) / 2.0 |
|
elif self.metric.lower() in ["euclidean", "dotproduct"]: |
|
|
|
return score |
|
else: |
|
|
|
return score |
|
|
|
def _result_to_get_result(self, matches: list) -> GetResult: |
|
"""Convert Pinecone matches to GetResult format.""" |
|
ids = [] |
|
documents = [] |
|
metadatas = [] |
|
|
|
for match in matches: |
|
metadata = match.get("metadata", {}) |
|
ids.append(match["id"]) |
|
documents.append(metadata.get("text", "")) |
|
metadatas.append(metadata) |
|
|
|
return GetResult( |
|
**{ |
|
"ids": [ids], |
|
"documents": [documents], |
|
"metadatas": [metadatas], |
|
} |
|
) |
|
|
|
def has_collection(self, collection_name: str) -> bool: |
|
"""Check if a collection exists by searching for at least one item.""" |
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
|
|
try: |
|
|
|
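            # Probe with a zero vector; any match means the collection holds at least one vector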
response = self.index.query( |
|
vector=[0.0] * self.dimension, |
|
top_k=1, |
|
filter={"collection_name": collection_name_with_prefix}, |
|
include_metadata=False, |
|
) |
|
return len(response.matches) > 0 |
|
except Exception as e: |
|
log.exception( |
|
f"Error checking collection '{collection_name_with_prefix}': {e}" |
|
) |
|
return False |
|
|
|
def delete_collection(self, collection_name: str) -> None: |
|
"""Delete a collection by removing all vectors with the collection name in metadata.""" |
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
try: |
|
self.index.delete(filter={"collection_name": collection_name_with_prefix}) |
|
log.info( |
|
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)." |
|
) |
|
except Exception as e: |
|
log.warning( |
|
f"Failed to delete collection '{collection_name_with_prefix}': {e}" |
|
) |
|
raise |
|
|
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None: |
|
"""Insert vectors into a collection.""" |
|
if not items: |
|
log.warning("No items to insert") |
|
return |
|
|
|
start_time = time.time() |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
points = self._create_points(items, collection_name_with_prefix) |
|
|
|
|
|
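        # Upsert in parallel batches on the shared thread pool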
executor = self._executor |
|
futures = [] |
|
for i in range(0, len(points), BATCH_SIZE): |
|
batch = points[i : i + BATCH_SIZE] |
|
futures.append(executor.submit(self.index.upsert, vectors=batch)) |
|
for future in concurrent.futures.as_completed(futures): |
|
try: |
|
future.result() |
|
except Exception as e: |
|
log.error(f"Error inserting batch: {e}") |
|
raise |
|
elapsed = time.time() - start_time |
|
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds") |
|
log.info( |
|
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" |
|
) |
|
|
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None: |
|
"""Upsert (insert or update) vectors into a collection.""" |
|
if not items: |
|
log.warning("No items to upsert") |
|
return |
|
|
|
start_time = time.time() |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
points = self._create_points(items, collection_name_with_prefix) |
|
|
|
|
|
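        # Same parallel batching strategy as insert()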
executor = self._executor |
|
futures = [] |
|
for i in range(0, len(points), BATCH_SIZE): |
|
batch = points[i : i + BATCH_SIZE] |
|
futures.append(executor.submit(self.index.upsert, vectors=batch)) |
|
for future in concurrent.futures.as_completed(futures): |
|
try: |
|
future.result() |
|
except Exception as e: |
|
log.error(f"Error upserting batch: {e}") |
|
raise |
|
elapsed = time.time() - start_time |
|
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds") |
|
log.info( |
|
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" |
|
) |
|
|
|
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None: |
|
"""Async version of insert using asyncio and run_in_executor for improved performance.""" |
|
if not items: |
|
log.warning("No items to insert") |
|
return |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
points = self._create_points(items, collection_name_with_prefix) |
|
|
|
|
|
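        # Dispatch each batch to the default executor and await them concurrently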
batches = [ |
|
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE) |
|
] |
|
        loop = asyncio.get_running_loop()
|
tasks = [ |
|
loop.run_in_executor( |
|
None, functools.partial(self.index.upsert, vectors=batch) |
|
) |
|
for batch in batches |
|
] |
|
results = await asyncio.gather(*tasks, return_exceptions=True) |
|
for result in results: |
|
if isinstance(result, Exception): |
|
log.error(f"Error in async insert batch: {result}") |
|
raise result |
|
log.info( |
|
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" |
|
) |
|
|
|
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None: |
|
"""Async version of upsert using asyncio and run_in_executor for improved performance.""" |
|
if not items: |
|
log.warning("No items to upsert") |
|
return |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
points = self._create_points(items, collection_name_with_prefix) |
|
|
|
|
|
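        # Same concurrent batching as insert_async()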
batches = [ |
|
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE) |
|
] |
|
        loop = asyncio.get_running_loop()
|
tasks = [ |
|
loop.run_in_executor( |
|
None, functools.partial(self.index.upsert, vectors=batch) |
|
) |
|
for batch in batches |
|
] |
|
results = await asyncio.gather(*tasks, return_exceptions=True) |
|
for result in results: |
|
if isinstance(result, Exception): |
|
log.error(f"Error in async upsert batch: {result}") |
|
raise result |
|
log.info( |
|
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" |
|
) |
|
|
|
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None: |
|
"""Perform a streaming upsert over gRPC for performance testing.""" |
|
if not items: |
|
log.warning("No items to upsert via streaming") |
|
return |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
points = self._create_points(items, collection_name_with_prefix) |
|
|
|
|
|
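        # NOTE: this assumes the gRPC index handle exposes a streaming_upsert() stream
        # with send()/close(); verify against the installed pinecone client version.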
stream = self.index.streaming_upsert() |
|
try: |
|
for point in points: |
|
|
|
stream.send(point) |
|
|
|
stream.close() |
|
log.info( |
|
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'" |
|
) |
|
except Exception as e: |
|
log.error(f"Error during streaming upsert: {e}") |
|
raise |
|
|
|
def search( |
|
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int |
|
) -> Optional[SearchResult]: |
|
"""Search for similar vectors in a collection.""" |
|
if not vectors or not vectors[0]: |
|
log.warning("No vectors provided for search") |
|
return None |
|
|
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
|
|
if limit is None or limit <= 0: |
|
limit = NO_LIMIT |
|
|
|
try: |
|
|
|
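            # Pinecone queries one vector at a time; use the first provided vector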
query_vector = vectors[0] |
|
|
|
|
|
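            # Restrict the search to this collection via the metadata filter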
query_response = self.index.query( |
|
vector=query_vector, |
|
top_k=limit, |
|
include_metadata=True, |
|
filter={"collection_name": collection_name_with_prefix}, |
|
) |
|
|
|
if not query_response.matches: |
|
|
|
return SearchResult( |
|
ids=[[]], |
|
documents=[[]], |
|
metadatas=[[]], |
|
distances=[[]], |
|
) |
|
|
|
|
|
get_result = self._result_to_get_result(query_response.matches) |
|
|
|
|
|
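            # Normalize match scores according to the configured metric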
distances = [ |
|
[ |
|
self._normalize_distance(match.score) |
|
for match in query_response.matches |
|
] |
|
] |
|
|
|
return SearchResult( |
|
ids=get_result.ids, |
|
documents=get_result.documents, |
|
metadatas=get_result.metadatas, |
|
distances=distances, |
|
) |
|
except Exception as e: |
|
log.error(f"Error searching in '{collection_name_with_prefix}': {e}") |
|
return None |
|
|
|
def query( |
|
self, collection_name: str, filter: Dict, limit: Optional[int] = None |
|
) -> Optional[GetResult]: |
|
"""Query vectors by metadata filter.""" |
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
|
|
if limit is None or limit <= 0: |
|
limit = NO_LIMIT |
|
|
|
try: |
|
|
|
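            # Pinecone requires a query vector even for metadata-only lookups; use a zero vector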
zero_vector = [0.0] * self.dimension |
|
|
|
|
|
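            # Combine the collection tag with the caller's filter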
pinecone_filter = {"collection_name": collection_name_with_prefix} |
|
if filter: |
|
pinecone_filter.update(filter) |
|
|
|
|
|
query_response = self.index.query( |
|
vector=zero_vector, |
|
filter=pinecone_filter, |
|
top_k=limit, |
|
include_metadata=True, |
|
) |
|
|
|
return self._result_to_get_result(query_response.matches) |
|
|
|
except Exception as e: |
|
log.error(f"Error querying collection '{collection_name}': {e}") |
|
return None |
|
|
|
def get(self, collection_name: str) -> Optional[GetResult]: |
|
"""Get all vectors in a collection.""" |
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
|
|
try: |
|
|
|
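            # Fetch everything in the collection with a zero-vector query (capped at NO_LIMIT)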
zero_vector = [0.0] * self.dimension |
|
|
|
|
|
query_response = self.index.query( |
|
vector=zero_vector, |
|
top_k=NO_LIMIT, |
|
include_metadata=True, |
|
filter={"collection_name": collection_name_with_prefix}, |
|
) |
|
|
|
return self._result_to_get_result(query_response.matches) |
|
|
|
except Exception as e: |
|
log.error(f"Error getting collection '{collection_name}': {e}") |
|
return None |
|
|
|
def delete( |
|
self, |
|
collection_name: str, |
|
ids: Optional[List[str]] = None, |
|
filter: Optional[Dict] = None, |
|
) -> None: |
|
"""Delete vectors by IDs or filter.""" |
|
collection_name_with_prefix = self._get_collection_name_with_prefix( |
|
collection_name |
|
) |
|
|
|
try: |
|
if ids: |
|
|
|
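                # Delete by ID in batches to keep request sizes manageable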
for i in range(0, len(ids), BATCH_SIZE): |
|
batch_ids = ids[i : i + BATCH_SIZE] |
|
|
|
|
|
self.index.delete(ids=batch_ids) |
|
log.debug( |
|
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'" |
|
) |
|
log.info( |
|
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'" |
|
) |
|
|
|
elif filter: |
|
|
|
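                # Scope the metadata-filter delete to this collection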
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                pinecone_filter.update(filter)
|
|
|
self.index.delete(filter=pinecone_filter) |
|
log.info( |
|
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'" |
|
) |
|
|
|
else: |
|
log.warning("No ids or filter provided for delete operation") |
|
|
|
except Exception as e: |
|
log.error(f"Error deleting from collection '{collection_name}': {e}") |
|
raise |
|
|
|
def reset(self) -> None: |
|
"""Reset the database by deleting all collections.""" |
|
try: |
|
self.index.delete(delete_all=True) |
|
log.info("All vectors successfully deleted from the index.") |
|
except Exception as e: |
|
log.error(f"Failed to reset Pinecone index: {e}") |
|
raise |
|
|
|
def close(self): |
|
"""Shut down the gRPC channel and thread pool.""" |
|
try: |
|
self.client.close() |
|
log.info("Pinecone gRPC channel closed.") |
|
except Exception as e: |
|
log.warning(f"Failed to close Pinecone gRPC channel: {e}") |
|
self._executor.shutdown(wait=True) |
|
|
|
def __enter__(self): |
|
"""Enter context manager.""" |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
"""Exit context manager, ensuring resources are cleaned up.""" |
|
self.close() |
|
|