Spaces:
Running
Running
File size: 3,597 Bytes
b110593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package classification
import (
"fmt"
"testing"
"github.com/go-openapi/strfmt"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/entities/search"
)
// testParallelBatchWrite drives batchWriter through one complete
// Start → Store (once per entry of items) → Stop cycle and then sends the
// WriterResults returned by Stop on resultChannel. The tests below launch it
// with `go` so that two independent writers run concurrently.
func testParallelBatchWrite(batchWriter Writer, items search.Results, resultChannel chan<- WriterResults) {
	batchWriter.Start()
	for i := range items {
		batchWriter.Store(items[i])
	}
	resultChannel <- batchWriter.Stop()
}
// generateSearchResultsToSave builds size synthetic "Article" search results.
// All of them share the same vector and the same "description" property; each
// one gets a distinct ID whose trailing digits are the zero-padded index i.
//
// NOTE: the generated ID has the canonical UUID layout (12 characters in the
// last group) only while i < 1000; for larger sizes the IDs stay distinct but
// the last group grows beyond 12 characters. A negative size is treated as
// zero and yields an empty slice.
func generateSearchResultsToSave(size int) search.Results {
	if size < 0 {
		// make panics on a negative capacity; clamp so a negative size
		// simply produces no items instead of crashing the test.
		size = 0
	}

	// The final length is known up front, so allocate the backing array once.
	items := make(search.Results, 0, size)
	for i := 0; i < size; i++ {
		items = append(items, search.Result{
			ID:        strfmt.UUID(fmt.Sprintf("75ba35af-6a08-40ae-b442-3bec69b35%03d", i)),
			ClassName: "Article",
			Vector:    []float32{0.78, 0, 0},
			Schema: map[string]interface{}{
				"description": "Barack Obama is a former US president",
			},
		})
	}
	return items
}
// TestWriter_SimpleWrite pushes the standard to-be-classified fixture through a
// single batch writer and expects every item to be counted as a success, with
// no failures and no aggregated error.
func TestWriter_SimpleWrite(t *testing.T) {
	// given
	searchResultsToBeSaved := testDataToBeClassified()
	vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified())
	batchWriter := newBatchWriter(vectorRepo)

	// when
	batchWriter.Start()
	for _, item := range searchResultsToBeSaved {
		batchWriter.Store(item)
	}
	res := batchWriter.Stop()

	// then
	assert.Equal(t, int64(len(searchResultsToBeSaved)), res.SuccessCount())
	assert.Equal(t, int64(0), res.ErrorCount())
	// NoError fails in the same cases as Equal(nil, ...) but prints the
	// error's message, which makes a failure much easier to diagnose.
	assert.NoError(t, res.Err())
}
// TestWriter_LoadWrites stores a larger generated batch (640 items) through a
// single batch writer and expects all of them to succeed without errors.
func TestWriter_LoadWrites(t *testing.T) {
	// given
	searchResultsCount := 640
	searchResultsToBeSaved := generateSearchResultsToSave(searchResultsCount)
	vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified())
	batchWriter := newBatchWriter(vectorRepo)

	// when
	batchWriter.Start()
	for _, item := range searchResultsToBeSaved {
		batchWriter.Store(item)
	}
	res := batchWriter.Stop()

	// then
	assert.Equal(t, int64(searchResultsCount), res.SuccessCount())
	assert.Equal(t, int64(0), res.ErrorCount())
	// NoError fails in the same cases as Equal(nil, ...) but prints the
	// error's message, which makes a failure much easier to diagnose.
	assert.NoError(t, res.Err())
}
// TestWriter_ParallelLoadWrites runs two independent batch writers, each backed
// by its own fake vector repo, concurrently in separate goroutines and checks
// that each one reports exactly its own item count as successes with no errors.
func TestWriter_ParallelLoadWrites(t *testing.T) {
	// given
	searchResultsToBeSavedCount1 := 600
	searchResultsToBeSavedCount2 := 440
	searchResultsToBeSaved1 := generateSearchResultsToSave(searchResultsToBeSavedCount1)
	searchResultsToBeSaved2 := generateSearchResultsToSave(searchResultsToBeSavedCount2)

	// A one-slot buffer lets each writer goroutine always hand off its result
	// and exit, independent of when (or whether) the receiver gets to it.
	vectorRepo1 := newFakeVectorRepoKNN(searchResultsToBeSaved1, testDataAlreadyClassified())
	batchWriter1 := newBatchWriter(vectorRepo1)
	resChannel1 := make(chan WriterResults, 1)

	vectorRepo2 := newFakeVectorRepoKNN(searchResultsToBeSaved2, testDataAlreadyClassified())
	batchWriter2 := newBatchWriter(vectorRepo2)
	resChannel2 := make(chan WriterResults, 1)

	// when
	go testParallelBatchWrite(batchWriter1, searchResultsToBeSaved1, resChannel1)
	go testParallelBatchWrite(batchWriter2, searchResultsToBeSaved2, resChannel2)
	res1 := <-resChannel1
	res2 := <-resChannel2

	// then
	assert.Equal(t, int64(searchResultsToBeSavedCount1), res1.SuccessCount())
	assert.Equal(t, int64(0), res1.ErrorCount())
	assert.NoError(t, res1.Err())

	assert.Equal(t, int64(searchResultsToBeSavedCount2), res2.SuccessCount())
	assert.Equal(t, int64(0), res2.ErrorCount())
	assert.NoError(t, res2.Err())
}
|