renesis / libs /kotaemon /tests /test_vectorstore.py
noumanjavaid's picture
Upload folder using huggingface_hub
ad33df7 verified
raw
history blame
14.2 kB
import json
import os
import pytest
from kotaemon.base import DocumentWithEmbedding
from kotaemon.storages import (
ChromaVectorStore,
InMemoryVectorStore,
MilvusVectorStore,
QdrantVectorStore,
SimpleFileVectorStore,
)
class TestChromaVectorStore:
    """Tests for ``ChromaVectorStore`` using a per-test temporary directory."""

    def test_add(self, tmp_path):
        """Test that adding raw embeddings returns the given ids and stores them."""
        db = ChromaVectorStore(path=str(tmp_path))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]

        assert db._collection.count() == 0, "Expected empty collection"
        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db._collection.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Test that adding ``DocumentWithEmbedding`` objects stores one entry each."""
        db = ChromaVectorStore(path=str(tmp_path))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]

        assert db._collection.count() == 0, "Expected empty collection"
        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db._collection.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Test that deleting by id removes exactly the requested entries."""
        db = ChromaVectorStore(path=str(tmp_path))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db._collection.count() == 3, "Expected 3 added entries"

        db.delete(ids=["a", "b"])
        assert db._collection.count() == 1, "Expected 1 remaining entry"

        db.delete(ids=["c"])
        assert db._collection.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Test that a top-1 query returns the nearest stored embedding."""
        db = ChromaVectorStore(path=str(tmp_path))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        # Querying with a stored vector must match itself with similarity ~1.0.
        _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
        # abs() is required: without it any similarity <= 1.0 (even 0) passes.
        assert abs(sim[0] - 1.0) < 1e-6
        assert out_ids == ["a"]

        # A slightly perturbed vector should still resolve to its neighbour "b".
        _, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)
        assert out_ids == ["b"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]

        db = ChromaVectorStore(path=str(tmp_path))
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        # A second store opened on the same path must see the persisted entries.
        db2 = ChromaVectorStore(path=str(tmp_path))
        assert (
            db2._collection.count() == 3
        ), "load function does not load data completely"

        # test delete collection function
        db2.drop()

        # reinit the chroma with the same collection name
        db2 = ChromaVectorStore(path=str(tmp_path))
        assert (
            db2._collection.count() == 0
        ), "delete collection function does not work correctly"
class TestInMemoryVectorStore:
    """Tests for ``InMemoryVectorStore`` including JSON save/load round-trips."""

    def test_add(self):
        """Test that add func adds correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]

        db = InMemoryVectorStore()
        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"

    def test_save_load_delete(self, tmp_path):
        """Test that delete, save and load funcs behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]
        save_path = tmp_path / "test_save_load_delete.json"

        db = InMemoryVectorStore()
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        db.delete(["3"])
        db.save(save_path=save_path)

        with open(save_path) as f:
            data = json.load(f)

        saved_ids = data["text_id_to_ref_doc_id"]
        # Check each id explicitly: `"1" and "2" in x` only tests "2", because
        # the non-empty string "1" is truthy and `and` returns the right operand.
        assert (
            "1" in saved_ids and "2" in saved_ids
        ), "save function does not save data completely"
        assert (
            "3" not in saved_ids
        ), "delete function does not delete data completely"

        # A fresh store loading the saved file must return the stored embedding.
        db2 = InMemoryVectorStore()
        db2.load(load_path=save_path)
        assert db2.get("2") == [
            0.4,
            0.5,
            0.6,
        ], "load function does not load data completely"
class TestSimpleFileVectorStore:
    """Tests for ``SimpleFileVectorStore``, which is read back from a file on disk."""

    def test_add_delete(self, tmp_path):
        """Test that add/delete are reflected in the file and in a reloaded store."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]
        collection_name = "test_save_load_delete"
        collection_file = tmp_path / collection_name

        db = SimpleFileVectorStore(path=tmp_path, collection_name=collection_name)
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        db.delete(["3"])

        # No explicit save call: the test expects the store to have written its
        # state to <path>/<collection_name> already.
        with open(collection_file) as f:
            data = json.load(f)

        saved_ids = data["text_id_to_ref_doc_id"]
        # Check each id explicitly: `"1" and "2" in x` only tests "2", because
        # the non-empty string "1" is truthy and `and` returns the right operand.
        assert (
            "1" in saved_ids and "2" in saved_ids
        ), "save function does not save data completely"
        assert (
            "3" not in saved_ids
        ), "delete function does not delete data completely"

        # A second store on the same path/collection must see the persisted data.
        db2 = SimpleFileVectorStore(path=tmp_path, collection_name=collection_name)
        assert db2.get("2") == [
            0.4,
            0.5,
            0.6,
        ], "load function does not load data completely"

        os.remove(collection_file)
class TestMilvusVectorStore:
    """Tests for ``MilvusVectorStore`` using a per-test temporary directory."""

    def test_add(self, tmp_path):
        """Test that adding raw embeddings returns the given ids and stores them."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = ["1", "2"]

        assert db.count() == 0, "Expected empty collection"
        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Test that adding ``DocumentWithEmbedding`` objects stores one entry each."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]

        assert db.count() == 0, "Expected empty collection"
        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Test that deleting by id removes exactly the requested entries."""
        db = MilvusVectorStore(
            path=str(tmp_path),
            overwrite=True,
        )

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db.count() == 3, "Expected 3 added entries"

        db.delete(ids=["a", "b"])
        assert db.count() == 1, "Expected 1 remaining entry"

        db.delete(ids=["c"])
        assert db.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Test that a top-1 query over unit-length embeddings returns the nearest."""
        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)

        import numpy as np

        # Scale every row to unit L2 norm before storing, so that a stored
        # vector queried against itself is expected to score ~1.0.
        embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
        norms = np.linalg.norm(embeddings, axis=1)
        normalized_embeddings = (embeddings / norms[:, np.newaxis]).tolist()

        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["a", "b", "c"]
        db.add(embeddings=normalized_embeddings, metadatas=metadatas, ids=ids)

        _, sim, out_ids = db.query(embedding=normalized_embeddings[0], top_k=1)
        # abs() is required: without it any similarity <= 1.0 (even 0) passes.
        assert abs(sim[0] - 1.0) < 1e-6
        assert out_ids == ["a"]

        # Perturb the second vector slightly; it should still resolve to "b".
        query_embedding = [
            normalized_embeddings[1][0] + 0.02,
            normalized_embeddings[1][1] + 0.02,
            normalized_embeddings[1][2] + 0.02,
        ]
        _, _, out_ids = db.query(embedding=query_embedding, top_k=1)
        assert out_ids == ["b"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = ["1", "2", "3"]

        db = MilvusVectorStore(path=str(tmp_path), overwrite=True)
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        # Reopen without overwriting so the previously stored entries are kept.
        # (The keyword is `overwrite`, as in every other construction here.)
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 3, "load function does not load data completely"

        # test delete collection function
        db2.drop()

        # reinit the milvus with the same collection name
        db2 = MilvusVectorStore(path=str(tmp_path), overwrite=False)
        assert db2.count() == 0, "delete collection function does not work correctly"
class TestQdrantVectorStore:
    """Tests for ``QdrantVectorStore`` using in-memory and on-disk local clients.

    ``qdrant_client`` is imported inside each test rather than at module level,
    matching the original layout of this file.
    """

    def test_add(self):
        """Test that adding raw embeddings returns the given ids and stores them."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
        ]

        output = db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert output == ids, "Expected output to be the same as ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_add_from_docs(self, tmp_path):
        """Test that adding ``DocumentWithEmbedding`` objects stores one entry each."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
        documents = [
            DocumentWithEmbedding(embedding=embedding, metadata=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]

        output = db.add(documents)
        assert len(output) == 2, "Expected outputting 2 ids"
        assert db.count() == 2, "Expected 2 added entries"

    def test_delete(self, tmp_path):
        """Test that deleting by id removes exactly the requested entries."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]

        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        assert db.count() == 3, "Expected 3 added entries"

        db.delete(
            ids=[
                "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
                "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            ]
        )
        assert db.count() == 1, "Expected 1 remaining entry"

        db.delete(ids=["6bed07c3-d284-47a3-a711-c3f9186755b8"])
        assert db.count() == 0, "Expected 0 remaining entry"

    def test_query(self, tmp_path):
        """Test that a top-1 query returns the stored embedding that was queried."""
        from qdrant_client import QdrantClient

        db = QdrantVectorStore(collection_name="test", client=QdrantClient(":memory:"))

        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)

        _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
        # abs() is required: without it any similarity <= 1.0 (even 0) passes.
        assert abs(sim[0] - 1.0) < 1e-6
        assert out_ids == ["0f0611b3-2d9c-4818-ab69-1f1c4cf66693"]

        _, _, out_ids = db.query(embedding=[0.4, 0.5, 0.6], top_k=1)
        assert out_ids == ["90aba5d3-f4f8-47c6-bad9-5ea457442e07"]

    def test_save_load_delete(self, tmp_path):
        """Test that save/load func behave correctly."""
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
        ids = [
            "0f0611b3-2d9c-4818-ab69-1f1c4cf66693",
            "90aba5d3-f4f8-47c6-bad9-5ea457442e07",
            "6bed07c3-d284-47a3-a711-c3f9186755b8",
        ]

        from qdrant_client import QdrantClient

        db = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )
        db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
        # NOTE(review): presumably dropping the reference releases the on-disk
        # client before the same path is reopened below -- confirm.
        del db

        db2 = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )
        assert db2.count() == 3

        db2.drop()
        del db2

        db2 = QdrantVectorStore(
            collection_name="test", client=QdrantClient(path=tmp_path)
        )
        with pytest.raises(Exception):
            # Since no docs were added, the collection should not exist yet
            # and thus the count function should raise an exception
            db2.count()