# renesis/libs/kotaemon/tests/test_indexing_retrieval.py
# Uploaded by noumanjavaid using huggingface_hub (commit ad33df7, 2.53 kB).
import json
from pathlib import Path
from typing import cast
from unittest.mock import patch
from openai.types.create_embedding_response import CreateEmbeddingResponse
from kotaemon.base import Document
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
# Canned OpenAI embeddings response, served by the patched
# ``Embeddings.create`` in the tests below so no network call is made.
# Decode explicitly as UTF-8 so loading does not depend on the platform locale.
with open(
    Path(__file__).parent / "resources" / "embedding_openai.json", encoding="utf-8"
) as f:
    openai_embedding = CreateEmbeddingResponse.model_validate(json.load(f))
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding,
)
def test_indexing(mock_create, tmp_path):
    """Indexing one document adds exactly one vector and one stored document.

    ``mock_create`` is the MagicMock injected by ``@patch``. ``mock.patch``
    passes it as an extra positional argument, so it must be the first
    parameter. Otherwise pytest binds the mock to ``tmp_path``, and the Chroma
    store is persisted under a directory named after the mock's repr in the
    current working directory instead of pytest's temporary directory.
    """
    db = ChromaVectorStore(path=str(tmp_path))
    doc_store = InMemoryDocumentStore()
    embedding = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        azure_endpoint="https://test.openai.azure.com/",
        api_key="some-key",
        api_version="version",
    )

    pipeline = VectorIndexing(vector_store=db, embedding=embedding, doc_store=doc_store)
    # Narrow the pipeline's declared store types for the type checker without
    # reassigning attributes on the component under test.
    vector_store = cast(ChromaVectorStore, pipeline.vector_store)
    index_doc_store = cast(InMemoryDocumentStore, pipeline.doc_store)

    assert vector_store._collection.count() == 0, "Expected empty collection"
    assert len(index_doc_store._store) == 0, "Expected empty doc store"

    pipeline(text=Document(text="Hello world"))

    # Embeddings must come from the patched client (the canned fixture).
    mock_create.assert_called()
    assert vector_store._collection.count() == 1, "Index 1 item"
    assert len(index_doc_store._store) == 1, "Expected 1 document"
@patch(
    "openai.resources.embeddings.Embeddings.create",
    side_effect=lambda *args, **kwargs: openai_embedding,
)
def test_retrieving(mock_create, tmp_path):
    """Retrieving after indexing one document returns it, deterministically.

    ``mock_create`` is the MagicMock injected by ``@patch``. It must come
    before ``tmp_path``, because ``mock.patch`` passes it positionally and
    pytest would otherwise bind the mock to ``tmp_path``. In that case Chroma
    would persist into a CWD directory named after the mock's repr.
    """
    db = ChromaVectorStore(path=str(tmp_path))
    doc_store = InMemoryDocumentStore()
    embedding = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        azure_endpoint="https://test.openai.azure.com/",
        api_key="some-key",
        api_version="version",
    )

    # Indexing and retrieval share the same vector store and doc store.
    index_pipeline = VectorIndexing(
        vector_store=db, embedding=embedding, doc_store=doc_store
    )
    retrieval_pipeline = VectorRetrieval(
        vector_store=db, doc_store=doc_store, embedding=embedding
    )

    index_pipeline(text=Document(text="Hello world"))
    output = retrieval_pipeline(text="Hello world")
    output1 = retrieval_pipeline(text="Hello world")

    # Both indexing and query embeddings must come from the patched client.
    mock_create.assert_called()
    assert len(output) == 1, "Expect 1 results"
    # Repeating the same query must not change the result.
    assert output == output1, "Expect identical results"