"""Milvus vector-database client adapter for Open WebUI retrieval."""
from pymilvus import MilvusClient as Client | |
from pymilvus import FieldSchema, DataType | |
import json | |
import logging | |
from typing import Optional | |
from open_webui.retrieval.vector.main import ( | |
VectorDBBase, | |
VectorItem, | |
SearchResult, | |
GetResult, | |
) | |
from open_webui.config import ( | |
MILVUS_URI, | |
MILVUS_DB, | |
MILVUS_TOKEN, | |
MILVUS_INDEX_TYPE, | |
MILVUS_METRIC_TYPE, | |
MILVUS_HNSW_M, | |
MILVUS_HNSW_EFCONSTRUCTION, | |
MILVUS_IVF_FLAT_NLIST, | |
) | |
from open_webui.env import SRC_LOG_LEVELS | |
# Module-level logger; verbosity follows the RAG component level from the env.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient(VectorDBBase):
    """Milvus-backed implementation of the VectorDBBase interface.

    All collections live under the ``open_webui`` prefix, and logical
    collection names have ``-`` replaced by ``_`` to satisfy Milvus
    naming rules.
    """

    def __init__(self):
        self.collection_prefix = "open_webui"
        # Token auth is optional; omit the kwarg entirely when unset.
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

    def _result_to_get_result(self, result) -> GetResult:
        """Convert a raw Milvus query result (list of row-dict lists) into a GetResult."""
        ids = []
        documents = []
        metadatas = []
        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # Document text is stored under the JSON "data" field as {"text": ...}.
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))
            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)
        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert a raw Milvus search result into a SearchResult with normalized scores."""
        ids = []
        distances = []
        documents = []
        metadatas = []
        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # Normalize milvus score from [-1, 1] to [0, 1] range.
                # https://milvus.io/docs/de/metric.md
                # NOTE(review): this assumes a metric whose scores lie in
                # [-1, 1] (e.g. COSINE/IP); L2 distances are unbounded —
                # verify against MILVUS_METRIC_TYPE.
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                # Search results nest payload fields under "entity".
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))
            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)
        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create a prefixed collection with id/vector/data/metadata fields and a vector index.

        Index and metric types come from config (MILVUS_INDEX_TYPE,
        MILVUS_METRIC_TYPE); build-time params depend on the index type.
        """
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Use configurations from config.py
        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()
        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        index_creation_params = {}
        if index_type == "HNSW":
            index_creation_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
            log.info(f"HNSW params: {index_creation_params}")
        elif index_type == "IVF_FLAT":
            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
            log.info(f"IVF_FLAT params: {index_creation_params}")
        elif index_type in ["FLAT", "AUTOINDEX"]:
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
            )
            # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
            # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    def _ensure_collection(
        self, collection_name: str, items: list[VectorItem], context: str = ""
    ):
        """Create the collection if it is missing, sizing the vector field from the first item.

        ``context`` is a suffix (e.g. " for upsert") interpolated into log and
        error messages so callers keep their original wording.

        Raises:
            ValueError: if the collection is absent and ``items`` is empty,
                since the vector dimension cannot be determined.
        """
        if self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            return
        log.info(
            f"Collection {self.collection_prefix}_{collection_name} does not exist{context}. Creating now."
        )
        if not items:
            log.error(
                f"Cannot create collection {self.collection_prefix}_{collection_name}{context} without items to determine dimension."
            )
            raise ValueError(
                f"Cannot create Milvus collection{context} without items to determine vector dimension."
            )
        self._create_collection(
            collection_name=collection_name, dimension=len(items[0]["vector"])
        )

    @staticmethod
    def _items_to_rows(items: list[VectorItem]) -> list[dict]:
        """Map VectorItems to the row dicts Milvus expects (text nested under "data")."""
        return [
            {
                "id": item["id"],
                "vector": item["vector"],
                "data": {"text": item["text"]},
                "metadata": item["metadata"],
            }
            for item in items
        ]

    def has_collection(self, collection_name: str) -> bool:
        """Check if the collection exists based on the collection name."""
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        """Delete the collection based on the collection name."""
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Search for the nearest neighbor items and return 'limit' results per query vector.

        Returns None when the collection does not exist (consistent with
        ``query``/``delete``) instead of letting pymilvus raise.
        """
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Search attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None
        # For some index types like IVF_FLAT, search params like nprobe can be set.
        # Example: search_params = {"nprobe": 10} if using IVF_FLAT
        # For simplicity, not adding configurable search_params here, but could be extended.
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
        )
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """Return items whose metadata matches ``filter``, paginating through Milvus.

        Args:
            collection_name: logical collection name (un-prefixed).
            filter: exact-match constraints, ANDed together over metadata keys.
            limit: maximum number of rows; None means "as many as practical"
                (capped at ten times the per-request maximum).

        Returns:
            A GetResult, or None if the collection is missing or the query fails.
        """
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        # Milvus boolean expression: every metadata key must match its value exactly.
        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )

        max_limit = 16383  # The maximum number of records per request
        if limit is None:
            # None implies "as many as possible"; fetch in batches up to a
            # large cap rather than a single unbounded request.
            limit = 16384 * 10
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        all_results = []
        # Offset/remaining drive the pagination loop below.
        offset = 0
        remaining = limit
        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit is reached.
            while remaining > 0:
                current_fetch = min(max_limit, remaining)
                log.debug(
                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
                )
                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    # Explicitly list needed fields; the vector is not needed here.
                    output_fields=["id", "data", "metadata"],
                    limit=current_fetch,
                    offset=offset,
                )
                if not results:
                    log.debug("No more results from query.")
                    break
                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")
                remaining -= results_count
                offset += results_count
                # A short batch means the data for this filter is exhausted.
                if results_count < current_fetch:
                    log.debug(
                        "Fetched less than requested, assuming end of results for this query."
                    )
                    break
            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all items in the collection. Resource-intensive for large collections."""
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        # A trivial filter routed through query() reuses the pagination logic.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Insert items; the collection is created first if it does not exist."""
        collection_name = collection_name.replace("-", "_")
        self._ensure_collection(collection_name, items)
        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=self._items_to_rows(items),
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Update items in place, inserting any that are missing; creates the collection if needed."""
        collection_name = collection_name.replace("-", "_")
        self._ensure_collection(collection_name, items, context=" for upsert")
        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=self._items_to_rows(items),
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items by explicit IDs or by a metadata filter.

        Returns None (with a warning) when the collection is missing or when
        neither ``ids`` nor ``filter`` is provided.
        """
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        """Drop every collection whose name carries this client's prefix."""
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    # Best-effort: keep dropping the remaining collections.
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")