genai_service / rag /repository.py
ahmed-eisa's picture
remove Qdrant database instance
a960c43
raw
history blame
3.25 kB
import os
import uuid
from typing import Optional

from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import ScoredPoint
class VectorRepository:
    """Async repository that stores and searches text embeddings in Qdrant.

    Each stored point carries a payload with the ``source`` and the
    ``original_text`` it was embedded from.
    """

    # Cluster URL used when neither the ``url`` argument nor the QDRANT_URL
    # environment variable is provided (preserves the previous default target).
    DEFAULT_URL = (
        "https://e8342d34-1b50-48e3-95e2-d4eacd0755eb.us-east4-0.gcp.cloud.qdrant.io:6333"
    )

    def __init__(
        self,
        host: str = "https://ahmed-eisa-qdrant-db.hf.space",
        port: int = 6333,
        *,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """Create the underlying ``AsyncQdrantClient``.

        Args:
            host: Accepted for backward compatibility only; currently unused.
            port: Accepted for backward compatibility only; currently unused.
            url: Qdrant endpoint. Falls back to the ``QDRANT_URL`` environment
                variable, then to ``DEFAULT_URL``.
            api_key: Qdrant API key. Falls back to the ``QDRANT_API_KEY``
                environment variable. Credentials must never be hard-coded.
        """
        # NOTE(review): an API key used to be hard-coded here and was committed
        # to source control; treat it as compromised and rotate it.
        resolved_url = url or os.environ.get("QDRANT_URL") or self.DEFAULT_URL
        resolved_key = api_key or os.environ.get("QDRANT_API_KEY")
        if not resolved_key:
            logger.warning(
                "No Qdrant API key provided (QDRANT_API_KEY unset); "
                "connecting without authentication"
            )
        self.db_client = AsyncQdrantClient(url=resolved_url, api_key=resolved_key)

    async def create_collection(self, collection_name: str, size: int) -> bool:
        """Create ``collection_name`` with cosine-distance vectors of length ``size``.

        If a collection with that name already exists it is deleted first, so
        all of its points are discarded.

        Returns:
            The boolean result of the client's ``create_collection`` call.
        """
        vectors_config = models.VectorParams(
            size=size, distance=models.Distance.COSINE
        )
        response = await self.db_client.get_collections()
        collection_exists = any(
            collection.name == collection_name
            for collection in response.collections
        )
        if collection_exists:
            logger.debug(
                f"Collection {collection_name} already exists - recreating it"
            )
            await self.db_client.delete_collection(collection_name)
        else:
            logger.debug(f"Creating collection {collection_name}")
        return await self.db_client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )

    async def delete_collection(self, name: str) -> bool:
        """Delete the collection ``name`` and return the client's boolean result."""
        logger.debug(f"Deleting collection {name}")
        return await self.db_client.delete_collection(name)

    async def create(
        self,
        collection_name: str,
        embedding_vector: list[float],
        original_text: str,
        source: str,
    ) -> None:
        """Store one embedding with its source text as a new point.

        A random UUID is used as the point ID. Deriving the ID from the
        collection's point count reused IDs after deletions or under
        concurrent inserts, so ``upsert`` silently overwrote existing points.
        """
        point_id = str(uuid.uuid4())
        logger.debug(
            f"Creating a new vector with ID {point_id} "
            f"inside the {collection_name}"
        )
        await self.db_client.upsert(
            collection_name=collection_name,
            points=[
                models.PointStruct(
                    id=point_id,
                    vector=embedding_vector,
                    payload={
                        "source": source,
                        "original_text": original_text,
                    },
                )
            ],
        )

    async def search(
        self,
        collection_name: str,
        query_vector: list[float],
        retrieval_limit: int,
        score_threshold: float,
    ) -> list[ScoredPoint]:
        """Return up to ``retrieval_limit`` points closest to ``query_vector``.

        ``score_threshold`` is forwarded to Qdrant as the minimum score a
        result must have to be returned.
        """
        logger.debug(
            f"Searching for relevant items in the {collection_name} collection"
        )
        # query_points takes the vector as ``query``; ``query_vector`` belongs to
        # the older ``search`` API and is rejected as an unknown argument here.
        response = await self.db_client.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=retrieval_limit,
            score_threshold=score_threshold,
        )
        return response.points