Spaces:
Paused
Paused
import json | |
from pathlib import Path | |
from typing import cast | |
from unittest.mock import patch | |
from openai.types.create_embedding_response import CreateEmbeddingResponse | |
from kotaemon.base import Document | |
from kotaemon.embeddings import AzureOpenAIEmbeddings | |
from kotaemon.indices import VectorIndexing, VectorRetrieval | |
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore | |
# Canned OpenAI embeddings API response, loaded from the test resources
# directory and validated into the SDK's response model. It is intended to
# stand in for a live embeddings call so the tests need no network access.
with open(
    Path(__file__).parent / "resources" / "embedding_openai.json", encoding="utf-8"
) as f:
    openai_embedding = CreateEmbeddingResponse.model_validate(json.load(f))
def test_indexing(tmp_path):
    """Indexing one document adds one entry to both the vector and doc stores.

    The embeddings call is patched to return the canned ``openai_embedding``
    response, so the fake Azure endpoint and key below are never contacted.
    """
    db = ChromaVectorStore(path=str(tmp_path))
    doc_store = InMemoryDocumentStore()
    embedding = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        azure_endpoint="https://test.openai.azure.com/",
        api_key="some-key",
        api_version="version",
    )

    pipeline = VectorIndexing(vector_store=db, embedding=embedding, doc_store=doc_store)
    # Narrow the declared store types so the private attributes type-check.
    pipeline.doc_store = cast(InMemoryDocumentStore, pipeline.doc_store)
    pipeline.vector_store = cast(ChromaVectorStore, pipeline.vector_store)

    assert pipeline.vector_store._collection.count() == 0, "Expected empty collection"
    assert len(pipeline.doc_store._store) == 0, "Expected empty doc store"

    # Replay the fixture instead of issuing a real (and failing) HTTP request.
    with patch(
        "openai.resources.embeddings.Embeddings.create",
        side_effect=lambda *args, **kwargs: openai_embedding,
    ):
        pipeline(text=Document(text="Hello world"))

    assert pipeline.vector_store._collection.count() == 1, "Index 1 item"
    assert len(pipeline.doc_store._store) == 1, "Expected 1 document"
def test_retrieving(tmp_path):
    """Retrieval after indexing one document returns it, deterministically.

    The indexing pass and both retrieval queries embed text, so all three
    run under a patch that returns the canned ``openai_embedding`` response
    instead of calling the fake Azure endpoint configured below.
    """
    db = ChromaVectorStore(path=str(tmp_path))
    doc_store = InMemoryDocumentStore()
    embedding = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        azure_endpoint="https://test.openai.azure.com/",
        api_key="some-key",
        api_version="version",
    )

    # Both pipelines share the same stores so retrieval sees what was indexed.
    index_pipeline = VectorIndexing(
        vector_store=db, embedding=embedding, doc_store=doc_store
    )
    retrieval_pipeline = VectorRetrieval(
        vector_store=db, doc_store=doc_store, embedding=embedding
    )

    with patch(
        "openai.resources.embeddings.Embeddings.create",
        side_effect=lambda *args, **kwargs: openai_embedding,
    ):
        index_pipeline(text=Document(text="Hello world"))
        output = retrieval_pipeline(text="Hello world")
        # A second identical query must yield exactly the same results.
        output1 = retrieval_pipeline(text="Hello world")

    assert len(output) == 1, "Expect 1 results"
    assert output == output1, "Expect identical results"