# Spaces: Running  (page-header residue from the hosting site; not part of the module)
import os

from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import ScoredPoint
class VectorRepository:
    """Async repository over a Qdrant vector database.

    Wraps an ``AsyncQdrantClient`` and exposes collection management,
    single-point insertion, and vector similarity search.
    """

    # Endpoint used when the QDRANT_URL environment variable is not set.
    _DEFAULT_URL = (
        "https://e8342d34-1b50-48e3-95e2-d4eacd0755eb.us-east4-0.gcp.cloud.qdrant.io:6333"
    )

    def __init__(
        self,
        host: str = "https://ahmed-eisa-qdrant-db.hf.space",
        port: int = 6333,
    ) -> None:
        """Create the underlying Qdrant client.

        Connection settings are read from the environment:

        * ``QDRANT_URL`` - endpoint URL (falls back to ``_DEFAULT_URL``).
        * ``QDRANT_API_KEY`` - API key; never commit it to source control.

        Args:
            host: Accepted for backward compatibility; currently not used
                to build the client.
            port: Accepted for backward compatibility; currently not used
                to build the client.
        """
        # SECURITY: the API key used to be hard-coded here. Any key that was
        # ever committed should be treated as exposed and rotated.
        api_key = os.environ.get("QDRANT_API_KEY") or None
        if api_key is None:
            logger.warning(
                "QDRANT_API_KEY is not set - connecting to Qdrant without an API key"
            )
        self.db_client = AsyncQdrantClient(
            url=os.environ.get("QDRANT_URL", self._DEFAULT_URL),
            api_key=api_key,
        )

    async def create_collection(self, collection_name: str, size: int) -> bool:
        """Create a cosine-distance collection, replacing any existing one.

        WARNING: if a collection called ``collection_name`` already exists
        it is deleted (together with its points) before being recreated.

        Args:
            collection_name: Name of the collection to (re)create.
            size: Dimensionality of the vectors stored in the collection.

        Returns:
            The result of the client's ``create_collection`` call.
        """
        vectors_config = models.VectorParams(
            size=size, distance=models.Distance.COSINE
        )
        response = await self.db_client.get_collections()
        collection_exists = any(
            collection.name == collection_name
            for collection in response.collections
        )
        if collection_exists:
            logger.debug(
                f"Collection {collection_name} already exists - recreating it"
            )
            await self.db_client.delete_collection(collection_name)
        else:
            logger.debug(f"Creating collection {collection_name}")
        return await self.db_client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )

    async def delete_collection(self, name: str) -> bool:
        """Delete the collection called ``name``.

        Returns:
            The result of the client's ``delete_collection`` call.
        """
        logger.debug(f"Deleting collection {name}")
        return await self.db_client.delete_collection(name)

    async def create(
        self,
        collection_name: str,
        embedding_vector: list[float],
        original_text: str,
        source: str,
    ) -> None:
        """Upsert one vector, with its text and source, into a collection.

        The point ID is the collection's current point count.

        Args:
            collection_name: Target collection.
            embedding_vector: Embedding to store.
            original_text: Text the embedding was computed from; stored in
                the payload under ``"original_text"``.
            source: Origin of the text; stored in the payload under
                ``"source"``.
        """
        # NOTE(review): deriving the ID from count() is not atomic. Concurrent
        # calls, or a count lowered by deletions, can reuse an existing ID, and
        # upsert would then overwrite that point. Consider UUIDs - confirm
        # nothing depends on sequential integer IDs first.
        response = await self.db_client.count(collection_name=collection_name)
        logger.debug(
            f"Creating a new vector with ID {response.count} "
            f"inside the {collection_name}"
        )
        await self.db_client.upsert(
            collection_name=collection_name,
            points=[
                models.PointStruct(
                    id=response.count,
                    vector=embedding_vector,
                    payload={
                        "source": source,
                        "original_text": original_text,
                    },
                )
            ],
        )

    async def search(
        self,
        collection_name: str,
        query_vector: list[float],
        retrieval_limit: int,
        score_threshold: float,
    ) -> list[ScoredPoint]:
        """Return the points most similar to ``query_vector``.

        Args:
            collection_name: Collection to search.
            query_vector: Embedding to compare against stored vectors.
            retrieval_limit: Maximum number of points to return.
            score_threshold: Score cut-off forwarded to the client.

        Returns:
            The ``points`` list of the client's query response.
        """
        logger.debug(
            f"Searching for relevant items in the {collection_name} collection"
        )
        response = await self.db_client.query_points(
            collection_name=collection_name,
            # query_points takes the vector as `query`; `query_vector` was the
            # keyword of the older `search` API and is rejected here.
            query=query_vector,
            limit=retrieval_limit,
            score_threshold=score_threshold,
        )
        return response.points