Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package classification | |
import ( | |
"fmt" | |
"testing" | |
"github.com/go-openapi/strfmt" | |
"github.com/stretchr/testify/assert" | |
"github.com/weaviate/weaviate/entities/search" | |
) | |
func testParallelBatchWrite(batchWriter Writer, items search.Results, resultChannel chan<- WriterResults) { | |
batchWriter.Start() | |
for _, item := range items { | |
batchWriter.Store(item) | |
} | |
res := batchWriter.Stop() | |
resultChannel <- res | |
} | |
func generateSearchResultsToSave(size int) search.Results { | |
items := make(search.Results, 0) | |
for i := 0; i < size; i++ { | |
res := search.Result{ | |
ID: strfmt.UUID(fmt.Sprintf("75ba35af-6a08-40ae-b442-3bec69b35%03d", i)), | |
ClassName: "Article", | |
Vector: []float32{0.78, 0, 0}, | |
Schema: map[string]interface{}{ | |
"description": "Barack Obama is a former US president", | |
}, | |
} | |
items = append(items, res) | |
} | |
return items | |
} | |
func TestWriter_SimpleWrite(t *testing.T) { | |
// given | |
searchResultsToBeSaved := testDataToBeClassified() | |
vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified()) | |
batchWriter := newBatchWriter(vectorRepo) | |
// when | |
batchWriter.Start() | |
for _, item := range searchResultsToBeSaved { | |
batchWriter.Store(item) | |
} | |
res := batchWriter.Stop() | |
// then | |
assert.Equal(t, int64(len(searchResultsToBeSaved)), res.SuccessCount()) | |
assert.Equal(t, int64(0), res.ErrorCount()) | |
assert.Equal(t, nil, res.Err()) | |
} | |
func TestWriter_LoadWrites(t *testing.T) { | |
// given | |
searchResultsCount := 640 | |
searchResultsToBeSaved := generateSearchResultsToSave(searchResultsCount) | |
vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified()) | |
batchWriter := newBatchWriter(vectorRepo) | |
// when | |
batchWriter.Start() | |
for _, item := range searchResultsToBeSaved { | |
batchWriter.Store(item) | |
} | |
res := batchWriter.Stop() | |
// then | |
assert.Equal(t, int64(searchResultsCount), res.SuccessCount()) | |
assert.Equal(t, int64(0), res.ErrorCount()) | |
assert.Equal(t, nil, res.Err()) | |
} | |
func TestWriter_ParallelLoadWrites(t *testing.T) { | |
// given | |
searchResultsToBeSavedCount1 := 600 | |
searchResultsToBeSavedCount2 := 440 | |
searchResultsToBeSaved1 := generateSearchResultsToSave(searchResultsToBeSavedCount1) | |
searchResultsToBeSaved2 := generateSearchResultsToSave(searchResultsToBeSavedCount2) | |
vectorRepo1 := newFakeVectorRepoKNN(searchResultsToBeSaved1, testDataAlreadyClassified()) | |
batchWriter1 := newBatchWriter(vectorRepo1) | |
resChannel1 := make(chan WriterResults) | |
vectorRepo2 := newFakeVectorRepoKNN(searchResultsToBeSaved2, testDataAlreadyClassified()) | |
batchWriter2 := newBatchWriter(vectorRepo2) | |
resChannel2 := make(chan WriterResults) | |
// when | |
go testParallelBatchWrite(batchWriter1, searchResultsToBeSaved1, resChannel1) | |
go testParallelBatchWrite(batchWriter2, searchResultsToBeSaved2, resChannel2) | |
res1 := <-resChannel1 | |
res2 := <-resChannel2 | |
// then | |
assert.Equal(t, int64(searchResultsToBeSavedCount1), res1.SuccessCount()) | |
assert.Equal(t, int64(0), res1.ErrorCount()) | |
assert.Equal(t, nil, res1.Err()) | |
assert.Equal(t, int64(searchResultsToBeSavedCount2), res2.SuccessCount()) | |
assert.Equal(t, int64(0), res2.ErrorCount()) | |
assert.Equal(t, nil, res2.Err()) | |
} | |