SemanticSearchPOC / adapters /repos /db /vector /hnsw /compress_recall_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
4.46 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
//go:build !race
package hnsw_test
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw"
"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
"github.com/weaviate/weaviate/entities/cyclemanager"
"github.com/weaviate/weaviate/entities/storobj"
ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
)
// distanceWrapper adapts a distancer.Provider into a bare
// func(x, y []float32) float32 so it can be passed to helpers such as
// testinghelpers.BruteForce. Only the first result of SingleDist (the
// distance) is kept. The two additional results are deliberately dropped,
// because this helper is used only with equal-length test vectors.
func distanceWrapper(provider distancer.Provider) func(x, y []float32) float32 {
	distanceOf := func(x, y []float32) float32 {
		d, _, _ := provider.SingleDist(x, y)
		return d
	}
	return distanceOf
}
func Test_NoRaceCompressionRecall(t *testing.T) {
path := t.TempDir()
efConstruction := 64
ef := 64
maxNeighbors := 32
segments := 4
dimensions := 64
vectors_size := 10000
queries_size := 100
fmt.Println("Sift1M PQ")
before := time.Now()
vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions)
testinghelpers.Normalize(vectors)
testinghelpers.Normalize(queries)
k := 100
distancers := []distancer.Provider{
distancer.NewL2SquaredProvider(),
distancer.NewCosineDistanceProvider(),
distancer.NewDotProductProvider(),
}
for _, distancer := range distancers {
truths := make([][]uint64, queries_size)
compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
truths[i], _ = testinghelpers.BruteForce(vectors, queries[i], k, distanceWrapper(distancer))
})
fmt.Printf("generating data took %s\n", time.Since(before))
uc := ent.UserConfig{
MaxConnections: maxNeighbors,
EFConstruction: efConstruction,
EF: ef,
VectorCacheMaxObjects: 10e12,
}
index, _ := hnsw.New(hnsw.Config{
RootPath: path,
ID: "recallbenchmark",
MakeCommitLoggerThunk: hnsw.MakeNoopCommitLogger,
ClassName: "clasRecallBenchmark",
ShardName: "shardRecallBenchmark",
DistanceProvider: distancer,
VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
if int(id) >= len(vectors) {
return nil, storobj.NewErrNotFoundf(id, "out of range")
}
return vectors[int(id)], nil
},
TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) {
copy(container.Slice, vectors[int(id)])
return container.Slice, nil
},
}, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
init := time.Now()
compressionhelpers.Concurrently(uint64(vectors_size), func(id uint64) {
index.Add(id, vectors[id])
})
before = time.Now()
fmt.Println("Start compressing...")
uc.PQ = ent.PQConfig{
Enabled: true,
Segments: dimensions / segments,
Centroids: 256,
Encoder: ent.NewDefaultUserConfig().PQ.Encoder,
}
uc.EF = 256
wg := sync.WaitGroup{}
wg.Add(1)
index.UpdateUserConfig(uc, func() {
fmt.Printf("Time to compress: %s\n", time.Since(before))
fmt.Printf("Building the index took %s\n", time.Since(init))
var relevant uint64
var retrieved int
var querying time.Duration = 0
compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
before = time.Now()
results, _, _ := index.SearchByVector(queries[i], k, nil)
querying += time.Since(before)
retrieved += k
relevant += testinghelpers.MatchesInLists(truths[i], results)
})
recall := float32(relevant) / float32(retrieved)
latency := float32(querying.Microseconds()) / float32(queries_size)
fmt.Println(recall, latency)
assert.True(t, recall > 0.9)
err := os.RemoveAll(path)
if err != nil {
fmt.Println(err)
}
wg.Done()
})
wg.Wait()
}
}