|
from typing import Any |
|
|
|
from qdrant_client import QdrantClient, models |
|
from uuid import uuid4 |
|
|
|
from src.config import QDRANT_COLLECTION_NAME |
|
|
|
|
|
class QdrantStore:
    """Thin wrapper around a ``QdrantClient`` that tracks known collections.

    The set of collection names is fetched from the server once at
    construction. It is updated when collections are created through this
    wrapper, or when ``_check_collection_name`` discovers them on the server.
    """

    def __init__(self, client: QdrantClient, collection_config: dict | None = None):
        """Wrap *client* and optionally create a collection.

        Args:
            client: Qdrant client used for every operation.
            collection_config: Optional keyword arguments forwarded to
                ``QdrantClient.create_collection``; must contain
                ``"collection_name"``. When given, the collection is created
                unless it already exists.
        """
        self.client = client
        # Snapshot of the collection names present on the server right now.
        self.collection_names = {c.name for c in client.get_collections().collections}

        if collection_config is not None:
            self.create_collection(collection_config)

    def create_collection(self, collection_config: dict):
        """Create a collection unless it already exists, and record its name.

        Args:
            collection_config: Keyword arguments for
                ``QdrantClient.create_collection``; ``"collection_name"`` is
                required.

        Raises:
            KeyError: If ``"collection_name"`` is missing from the config.
        """
        collection_name = collection_config["collection_name"]
        if not self.client.collection_exists(collection_name):
            self.client.create_collection(**collection_config)
        # Record the name in both branches: the collection exists on the
        # server either way.
        self.collection_names.add(collection_name)

    def _check_collection_name(self, collection_name: str):
        """Raise ``ValueError`` unless *collection_name* exists.

        The local cache is consulted first. On a miss the server is asked,
        so a collection created elsewhere after this store was constructed
        is still accepted (and then cached).

        Raises:
            ValueError: If the collection exists neither in the cache nor
                on the server.
        """
        if collection_name in self.collection_names:
            return
        if self.client.collection_exists(collection_name):
            self.collection_names.add(collection_name)
            return
        raise ValueError(f"Collection: {collection_name} does not exist.")

    def upsert_points(self,
                      vectors: Any | list[Any],
                      payloads: dict | list[dict],
                      collection_name: str) -> list[str]:
        """Upsert points under freshly generated UUID ids.

        Args:
            vectors: One vector per point, or a single vector when
                *payloads* is a single dict. If a dict is given, it is
                treated as a ``{vector_name: vector(s)}`` mapping.
            payloads: One payload dict per point, or a single dict for a
                single point.
            collection_name: Target collection; must exist.

        Returns:
            The string ids assigned to the points, in input order.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)

        if isinstance(payloads, dict):
            # Single point. Iterating the dict would yield one id per *key*,
            # and Batch expects per-point sequences, so wrap payload and
            # vector(s) into one-element batches.
            payloads = [payloads]
            if isinstance(vectors, dict):
                vectors = {name: [vec] for name, vec in vectors.items()}
            else:
                vectors = [vectors]

        ids = [str(uuid4()) for _ in payloads]

        self.client.upsert(
            collection_name=collection_name,
            points=models.Batch(ids=ids, payloads=payloads, vectors=vectors),
        )
        return ids

    def delete_points(self,
                      filters: dict[str, list[models.FieldCondition]],
                      collection_name: str):
        """Delete every point matching a filter.

        Args:
            filters: Keyword arguments for ``models.Filter``, e.g.
                ``{"must": [FieldCondition(...)]}``.
            collection_name: Target collection; must exist.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)

        self.client.delete(
            collection_name=collection_name,
            points_selector=models.Filter(**filters),
        )

    def delete_points_by_match(self,
                               key_value: tuple[str, list[str] | str],
                               collection_name: str):
        """Delete points whose payload field matches any of the given values.

        Args:
            key_value: ``(key, values)`` pair; *values* may be a single
                string or a list of strings. Matching uses ``MatchAny``.
            collection_name: Target collection; must exist.

        Raises:
            ValueError: If the collection does not exist.
        """
        key, values = key_value
        if isinstance(values, str):
            values = [values]

        # Named to avoid shadowing the builtin ``filter``.
        match_filter = {
            "must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]
        }
        self.delete_points(match_filter, collection_name)
|
|