//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: [email protected]
//
package flat | |
import ( | |
"context" | |
"encoding/binary" | |
"fmt" | |
"io" | |
"math" | |
"strings" | |
"sync" | |
"sync/atomic" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/adapters/repos/db/helpers" | |
"github.com/weaviate/weaviate/adapters/repos/db/lsmkv" | |
"github.com/weaviate/weaviate/adapters/repos/db/priorityqueue" | |
"github.com/weaviate/weaviate/adapters/repos/db/vector/cache" | |
"github.com/weaviate/weaviate/adapters/repos/db/vector/common" | |
"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" | |
"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" | |
"github.com/weaviate/weaviate/entities/schema" | |
flatent "github.com/weaviate/weaviate/entities/vectorindex/flat" | |
"github.com/weaviate/weaviate/usecases/floatcomp" | |
) | |
// Compression identifiers stored in flat.compression, derived from the
// user config by extractCompression.
const (
	compressionBQ   = "bq"
	compressionPQ   = "pq"
	compressionNone = "none"
)
// flat is a brute-force (non-graph) vector index: vectors live in LSM
// buckets and every search scans all candidates. Optionally, binary
// quantization (BQ) stores a compressed copy of each vector, with an
// in-memory cache on top for the compressed representations.
type flat struct {
	sync.Mutex
	id                  string
	dims                int32 // dimensionality, fixed by the first Add (written via atomic store)
	store               *lsmkv.Store
	logger              logrus.FieldLogger
	distancerProvider   distancer.Provider
	trackDimensionsOnce sync.Once
	rescore             int64 // read/written atomically; see searchTimeRescore / UpdateUserConfig
	bq                  compressionhelpers.BinaryQuantizer
	pqResults           *common.PqMaxPool
	pool                *pools
	compression         string // one of compressionBQ / compressionPQ / compressionNone
	bqCache             cache.Cache[uint64]
}

// distanceCalc scores a candidate vector, given in its raw on-disk byte
// representation, against a query captured in the closure.
type distanceCalc func(vecAsBytes []byte) (float32, error)
func New(cfg Config, uc flatent.UserConfig, store *lsmkv.Store) (*flat, error) { | |
if err := cfg.Validate(); err != nil { | |
return nil, errors.Wrap(err, "invalid config") | |
} | |
logger := cfg.Logger | |
if logger == nil { | |
l := logrus.New() | |
l.Out = io.Discard | |
logger = l | |
} | |
index := &flat{ | |
id: cfg.ID, | |
logger: logger, | |
distancerProvider: cfg.DistanceProvider, | |
rescore: extractCompressionRescore(uc), | |
pqResults: common.NewPqMaxPool(100), | |
compression: extractCompression(uc), | |
pool: newPools(), | |
store: store, | |
} | |
index.initBuckets(context.Background()) | |
if uc.BQ.Enabled && uc.BQ.Cache { | |
index.bqCache = cache.NewShardedUInt64LockCache(index.getBQVector, uc.VectorCacheMaxObjects, cfg.Logger, 0) | |
} | |
return index, nil | |
} | |
// getBQVector is the loader callback for the BQ cache: it reads the
// binary-quantized vector for id from the compressed bucket and decodes it
// into uint64 words. The 8-byte key buffer is pooled to avoid a per-lookup
// allocation; keys are big-endian so bucket order equals id order.
func (flat *flat) getBQVector(ctx context.Context, id uint64) ([]uint64, error) {
	key := flat.pool.byteSlicePool.Get(8)
	defer flat.pool.byteSlicePool.Put(key)
	binary.BigEndian.PutUint64(key.slice, id)
	bytes, err := flat.store.Bucket(helpers.VectorsCompressedBucketLSM).Get(key.slice)
	if err != nil {
		return nil, err
	}
	return uint64SliceFromByteSlice(bytes, make([]uint64, len(bytes)/8)), nil
}
func extractCompression(uc flatent.UserConfig) string { | |
if uc.BQ.Enabled && uc.PQ.Enabled { | |
return compressionNone | |
} | |
if uc.BQ.Enabled { | |
return compressionBQ | |
} | |
if uc.PQ.Enabled { | |
return compressionPQ | |
} | |
return compressionNone | |
} | |
func extractCompressionRescore(uc flatent.UserConfig) int64 { | |
compression := extractCompression(uc) | |
switch compression { | |
case compressionPQ: | |
return int64(uc.PQ.RescoreLimit) | |
case compressionBQ: | |
return int64(uc.BQ.RescoreLimit) | |
default: | |
return 0 | |
} | |
} | |
// storeCompressedVector persists a binary-quantized vector under id in the
// compressed bucket.
func (index *flat) storeCompressedVector(id uint64, vector []byte) {
	index.storeGenericVector(id, vector, helpers.VectorsCompressedBucketLSM)
}

// storeVector persists an uncompressed vector under id.
func (index *flat) storeVector(id uint64, vector []byte) {
	index.storeGenericVector(id, vector, helpers.VectorsBucketLSM)
}

// storeGenericVector writes vector into the named bucket keyed by the
// big-endian encoding of id (big-endian keeps LSM keys sorted by id).
// NOTE(review): the Put error is discarded here — confirm that is intended.
func (index *flat) storeGenericVector(id uint64, vector []byte, bucket string) {
	idBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(idBytes, id)
	index.store.Bucket(bucket).Put(idBytes, vector)
}
// isBQ reports whether binary quantization is the configured compression.
func (index *flat) isBQ() bool {
	return index.compression == compressionBQ
}

// isBQCached reports whether a BQ vector cache was set up (BQ enabled with
// caching in the user config; see New).
func (index *flat) isBQCached() bool {
	return index.bqCache != nil
}

// Compressed reports whether any compression is active.
func (index *flat) Compressed() bool {
	return index.compression != compressionNone
}
func (index *flat) initBuckets(ctx context.Context) error { | |
if err := index.store.CreateOrLoadBucket(ctx, helpers.VectorsBucketLSM, | |
lsmkv.WithForceCompation(true), | |
lsmkv.WithUseBloomFilter(false), | |
lsmkv.WithCalcCountNetAdditions(false), | |
); err != nil { | |
return fmt.Errorf("Create or load flat vectors bucket: %w", err) | |
} | |
if index.isBQ() { | |
if err := index.store.CreateOrLoadBucket(ctx, helpers.VectorsCompressedBucketLSM, | |
lsmkv.WithForceCompation(true), | |
lsmkv.WithUseBloomFilter(false), | |
lsmkv.WithCalcCountNetAdditions(false), | |
); err != nil { | |
return fmt.Errorf("Create or load flat compressed vectors bucket: %w", err) | |
} | |
} | |
return nil | |
} | |
func (index *flat) AddBatch(ctx context.Context, ids []uint64, vectors [][]float32) error { | |
if err := ctx.Err(); err != nil { | |
return err | |
} | |
if len(ids) != len(vectors) { | |
return errors.Errorf("ids and vectors sizes does not match") | |
} | |
if len(ids) == 0 { | |
return errors.Errorf("insertBatch called with empty lists") | |
} | |
for i := range ids { | |
if err := ctx.Err(); err != nil { | |
return err | |
} | |
if err := index.Add(ids[i], vectors[i]); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
// byteSliceFromUint64Slice serializes vector into slice, little-endian,
// 8 bytes per element. slice must hold at least len(vector)*8 bytes.
func byteSliceFromUint64Slice(vector []uint64, slice []byte) []byte {
	for i, word := range vector {
		binary.LittleEndian.PutUint64(slice[i*8:], word)
	}
	return slice
}

// byteSliceFromFloat32Slice serializes vector into slice, little-endian
// IEEE-754 bits, 4 bytes per element. slice must hold len(vector)*4 bytes.
func byteSliceFromFloat32Slice(vector []float32, slice []byte) []byte {
	for i, val := range vector {
		binary.LittleEndian.PutUint32(slice[i*4:], math.Float32bits(val))
	}
	return slice
}

// uint64SliceFromByteSlice decodes little-endian uint64 words from vector
// into slice; the number of words decoded is len(slice).
func uint64SliceFromByteSlice(vector []byte, slice []uint64) []uint64 {
	for i := 0; i < len(slice); i++ {
		slice[i] = binary.LittleEndian.Uint64(vector[i*8:])
	}
	return slice
}

// float32SliceFromByteSlice decodes little-endian float32 values from
// vector into slice; the number of values decoded is len(slice).
func float32SliceFromByteSlice(vector []byte, slice []float32) []float32 {
	for i := 0; i < len(slice); i++ {
		slice[i] = math.Float32frombits(binary.LittleEndian.Uint32(vector[i*4:]))
	}
	return slice
}
func (index *flat) Add(id uint64, vector []float32) error { | |
index.trackDimensionsOnce.Do(func() { | |
atomic.StoreInt32(&index.dims, int32(len(vector))) | |
if index.isBQ() { | |
index.bq = compressionhelpers.NewBinaryQuantizer(nil) | |
} | |
}) | |
if len(vector) != int(index.dims) { | |
return errors.Errorf("insert called with a vector of the wrong size") | |
} | |
vector = index.normalized(vector) | |
slice := make([]byte, len(vector)*4) | |
index.storeVector(id, byteSliceFromFloat32Slice(vector, slice)) | |
if index.isBQ() { | |
vectorBQ := index.bq.Encode(vector) | |
if index.isBQCached() { | |
index.bqCache.Grow(id) | |
index.bqCache.Preload(id, vectorBQ) | |
} | |
slice = make([]byte, len(vectorBQ)*8) | |
index.storeCompressedVector(id, byteSliceFromUint64Slice(vectorBQ, slice)) | |
} | |
return nil | |
} | |
func (index *flat) Delete(ids ...uint64) error { | |
for i := range ids { | |
if index.isBQCached() { | |
index.bqCache.Delete(context.Background(), ids[i]) | |
} | |
idBytes := make([]byte, 8) | |
binary.BigEndian.PutUint64(idBytes, ids[i]) | |
if err := index.store.Bucket(helpers.VectorsBucketLSM).Delete(idBytes); err != nil { | |
return err | |
} | |
if index.isBQ() { | |
if err := index.store.Bucket(helpers.VectorsCompressedBucketLSM).Delete(idBytes); err != nil { | |
return err | |
} | |
} | |
} | |
return nil | |
} | |
// searchTimeRescore returns how many candidates the compressed first pass
// should collect: the configured rescore limit, but never fewer than the
// requested k.
func (index *flat) searchTimeRescore(k int) int {
	// load atomically, so we can get away with concurrent updates of the
	// userconfig without having to set a lock each time we try to read - which
	// can be so common that it would cause considerable overhead
	if rescore := int(atomic.LoadInt64(&index.rescore)); rescore > k {
		return rescore
	}
	return k
}
// SearchByVector returns the ids and distances of the k closest stored
// vectors, optionally restricted to the allow list. It dispatches on the
// configured compression; PQ currently falls back to the uncompressed path.
func (index *flat) SearchByVector(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	switch index.compression {
	case compressionBQ:
		return index.searchByVectorBQ(vector, k, allow)
	case compressionPQ:
		// use uncompressed for now
		fallthrough
	default:
		return index.searchByVector(vector, k, allow)
	}
}
// searchByVector is the uncompressed search path: it scans the raw vectors
// bucket and keeps the k closest candidates in a pooled max-heap.
func (index *flat) searchByVector(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	heap := index.pqResults.GetMax(k)
	defer index.pqResults.Put(heap)
	// normalize the query the same way stored vectors were normalized on Add
	vector = index.normalized(vector)
	if err := index.findTopVectors(heap, allow, k,
		index.store.Bucket(helpers.VectorsBucketLSM).Cursor,
		index.createDistanceCalc(vector),
	); err != nil {
		return nil, nil, err
	}
	ids, dists := index.extractHeap(heap)
	return ids, dists, nil
}

// createDistanceCalc returns a distanceCalc that decodes a raw float32
// vector from its byte representation (into a pooled scratch slice, valid
// only for the duration of the call) and scores it against the query.
func (index *flat) createDistanceCalc(vector []float32) distanceCalc {
	return func(vecAsBytes []byte) (float32, error) {
		vecSlice := index.pool.float32SlicePool.Get(len(vecAsBytes) / 4)
		defer index.pool.float32SlicePool.Put(vecSlice)
		candidate := float32SliceFromByteSlice(vecAsBytes, vecSlice.slice)
		distance, _, err := index.distancerProvider.SingleDist(vector, candidate)
		return distance, err
	}
}
// searchByVectorBQ is the binary-quantization search path. Phase one
// collects `rescore` (>= k) candidates by compressed distance — from the
// cache when available, otherwise by scanning the compressed bucket. Phase
// two re-scores those candidates against their full float32 vectors and
// keeps only the best k.
func (index *flat) searchByVectorBQ(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	rescore := index.searchTimeRescore(k)
	heap := index.pqResults.GetMax(rescore)
	defer index.pqResults.Put(heap)
	vector = index.normalized(vector)
	vectorBQ := index.bq.Encode(vector)
	if index.isBQCached() {
		if err := index.findTopVectorsCached(heap, allow, rescore, vectorBQ); err != nil {
			return nil, nil, err
		}
	} else {
		if err := index.findTopVectors(heap, allow, rescore,
			index.store.Bucket(helpers.VectorsCompressedBucketLSM).Cursor,
			index.createDistanceCalcBQ(vectorBQ),
		); err != nil {
			return nil, nil, err
		}
	}
	// drain the candidate ids, then re-insert them into the (now empty)
	// heap with exact distances from the uncompressed vectors, capped at k
	distanceCalc := index.createDistanceCalc(vector)
	idsSlice := index.pool.uint64SlicePool.Get(heap.Len())
	defer index.pool.uint64SlicePool.Put(idsSlice)
	for i := range idsSlice.slice {
		idsSlice.slice[i] = heap.Pop().ID
	}
	for _, id := range idsSlice.slice {
		candidateAsBytes, err := index.vectorById(id)
		if err != nil {
			return nil, nil, err
		}
		distance, err := distanceCalc(candidateAsBytes)
		if err != nil {
			return nil, nil, err
		}
		index.insertToHeap(heap, k, id, distance)
	}
	ids, dists := index.extractHeap(heap)
	return ids, dists, nil
}

// createDistanceCalcBQ returns a distanceCalc over binary-quantized
// vectors: it decodes the candidate's uint64 words into a pooled scratch
// slice and computes the compressed-domain distance to vectorBQ.
func (index *flat) createDistanceCalcBQ(vectorBQ []uint64) distanceCalc {
	return func(vecAsBytes []byte) (float32, error) {
		vecSliceBQ := index.pool.uint64SlicePool.Get(len(vecAsBytes) / 8)
		defer index.pool.uint64SlicePool.Put(vecSliceBQ)
		candidate := uint64SliceFromByteSlice(vecAsBytes, vecSliceBQ.slice)
		return index.bq.DistanceBetweenCompressedVectors(candidate, vectorBQ)
	}
}
// vectorById fetches the raw (uncompressed) vector bytes for id from the
// vectors bucket, using a pooled buffer for the big-endian key.
func (index *flat) vectorById(id uint64) ([]byte, error) {
	idSlice := index.pool.byteSlicePool.Get(8)
	defer index.pool.byteSlicePool.Put(idSlice)
	binary.BigEndian.PutUint64(idSlice.slice, id)
	return index.store.Bucket(helpers.VectorsBucketLSM).Get(idSlice.slice)
}
// findTopVectors populates the heap with the `limit` smallest distances and
// corresponding ids, scanning the bucket cursor and scoring each candidate
// with distanceCalc. With an allow list, the scan seeks to the smallest
// allowed id and terminates once past the largest one (keys are big-endian
// encoded, so cursor order equals id order).
func (index *flat) findTopVectors(heap *priorityqueue.Queue[any],
	allow helpers.AllowList, limit int, cursorFn func() *lsmkv.CursorReplace,
	distanceCalc distanceCalc,
) error {
	var key []byte
	var v []byte
	var id uint64
	allowMax := uint64(0)
	cursor := cursorFn()
	defer cursor.Close()
	if allow != nil {
		// nothing allowed, skip search
		if allow.IsEmpty() {
			return nil
		}
		allowMax = allow.Max()
		// seek straight to the first allowed id; key buffer is pooled
		idSlice := index.pool.byteSlicePool.Get(8)
		binary.BigEndian.PutUint64(idSlice.slice, allow.Min())
		key, v = cursor.Seek(idSlice.slice)
		index.pool.byteSlicePool.Put(idSlice)
	} else {
		key, v = cursor.First()
	}
	// since keys are sorted, once key/id get greater than max allowed one
	// further search can be stopped
	// NOTE(review): the loop condition sees the id decoded in the previous
	// iteration, so the loop runs one extra step past allowMax before
	// exiting; harmless (Contains filters it) but confirm it is intended.
	for ; key != nil && (allow == nil || id <= allowMax); key, v = cursor.Next() {
		id = binary.BigEndian.Uint64(key)
		if allow == nil || allow.Contains(id) {
			distance, err := distanceCalc(v)
			if err != nil {
				return err
			}
			index.insertToHeap(heap, limit, id, distance)
		}
	}
	return nil
}
// findTopVectorsCached is the cache-backed variant of findTopVectors: it
// iterates candidate ids directly in [0, bqCache.Len()) instead of walking
// a bucket cursor, computing compressed distances against vectorBQ.
// NOTE(review): this assumes ids are dense below the cache length; ids at
// or above bqCache.Len() are never considered — confirm that is guaranteed
// by how the cache is grown on insert.
func (index *flat) findTopVectorsCached(heap *priorityqueue.Queue[any],
	allow helpers.AllowList, limit int, vectorBQ []uint64,
) error {
	var id uint64
	allowMax := uint64(0)
	if allow != nil {
		// nothing allowed, skip search
		if allow.IsEmpty() {
			return nil
		}
		allowMax = allow.Max()
		id = allow.Min()
	} else {
		id = 0
	}
	all := index.bqCache.Len()
	// since keys are sorted, once key/id get greater than max allowed one
	// further search can be stopped
	for ; id < uint64(all) && (allow == nil || id <= allowMax); id++ {
		if allow == nil || allow.Contains(id) {
			vec, err := index.bqCache.Get(context.Background(), id)
			if err != nil {
				return err
			}
			// empty slot: no vector stored under this id
			if len(vec) == 0 {
				continue
			}
			distance, err := index.bq.DistanceBetweenCompressedVectors(vec, vectorBQ)
			if err != nil {
				return err
			}
			index.insertToHeap(heap, limit, id, distance)
		}
	}
	return nil
}
func (index *flat) insertToHeap(heap *priorityqueue.Queue[any], | |
limit int, id uint64, distance float32, | |
) { | |
if heap.Len() < limit { | |
heap.Insert(id, distance) | |
} else if heap.Top().Dist > distance { | |
heap.Pop() | |
heap.Insert(id, distance) | |
} | |
} | |
func (index *flat) extractHeap(heap *priorityqueue.Queue[any], | |
) ([]uint64, []float32) { | |
len := heap.Len() | |
ids := make([]uint64, len) | |
dists := make([]float32, len) | |
for i := len - 1; i >= 0; i-- { | |
item := heap.Pop() | |
ids[i] = item.ID | |
dists[i] = item.Dist | |
} | |
return ids, dists | |
} | |
// normalized returns the vector to use for storage/search: normalized for
// cosine-dot distancers, unchanged otherwise.
func (index *flat) normalized(vector []float32) []float32 {
	if index.distancerProvider.Type() == "cosine-dot" {
		// cosine-dot requires normalized vectors, as the dot product and cosine
		// similarity are only identical if the vector is normalized
		return distancer.Normalize(vector)
	}
	return vector
}
func (index *flat) SearchByVectorDistance(vector []float32, targetDistance float32, maxLimit int64, allow helpers.AllowList) ([]uint64, []float32, error) { | |
var ( | |
searchParams = newSearchByDistParams(maxLimit) | |
resultIDs []uint64 | |
resultDist []float32 | |
) | |
recursiveSearch := func() (bool, error) { | |
totalLimit := searchParams.TotalLimit() | |
ids, dist, err := index.SearchByVector(vector, totalLimit, allow) | |
if err != nil { | |
return false, errors.Wrap(err, "vector search") | |
} | |
// if there is less results than given limit search can be stopped | |
shouldContinue := !(len(ids) < totalLimit) | |
// ensures the indexes aren't out of range | |
offsetCap := searchParams.OffsetCapacity(ids) | |
totalLimitCap := searchParams.TotalLimitCapacity(ids) | |
if offsetCap == totalLimitCap { | |
return false, nil | |
} | |
ids, dist = ids[offsetCap:totalLimitCap], dist[offsetCap:totalLimitCap] | |
for i := range ids { | |
if aboveThresh := dist[i] <= targetDistance; aboveThresh || | |
floatcomp.InDelta(float64(dist[i]), float64(targetDistance), 1e-6) { | |
resultIDs = append(resultIDs, ids[i]) | |
resultDist = append(resultDist, dist[i]) | |
} else { | |
// as soon as we encounter a certainty which | |
// is below threshold, we can stop searching | |
shouldContinue = false | |
break | |
} | |
} | |
return shouldContinue, nil | |
} | |
var shouldContinue bool | |
var err error | |
for shouldContinue, err = recursiveSearch(); shouldContinue && err == nil; { | |
searchParams.Iterate() | |
if searchParams.MaxLimitReached() { | |
index.logger. | |
WithField("action", "unlimited_vector_search"). | |
Warnf("maximum search limit of %d results has been reached", | |
searchParams.MaximumSearchLimit()) | |
break | |
} | |
} | |
if err != nil { | |
return nil, nil, err | |
} | |
return resultIDs, resultDist, nil | |
} | |
// UpdateUserConfig applies the runtime-mutable subset of the user config
// (currently only the rescore limit). callback is invoked exactly once,
// also on the type-assertion error path.
func (index *flat) UpdateUserConfig(updated schema.VectorIndexConfig, callback func()) error {
	parsed, ok := updated.(flatent.UserConfig)
	if !ok {
		callback()
		return errors.Errorf("config is not UserConfig, but %T", updated)
	}
	// Store automatically as a lock here would be very expensive, this value is
	// read on every single user-facing search, which can be highly concurrent
	atomic.StoreInt64(&index.rescore, extractCompressionRescore(parsed))
	callback()
	return nil
}
// Drop is a no-op: the shard owns and drops the store's buckets.
func (index *flat) Drop(ctx context.Context) error {
	// nothing to do here
	// Shard::drop will take care of handling store's buckets
	return nil
}

// Flush is a no-op: the shard flushes the store's buckets.
func (index *flat) Flush() error {
	// nothing to do here
	// Shard will take care of handling store's buckets
	return nil
}

// Shutdown is a no-op: the shard shuts down the store's buckets.
func (index *flat) Shutdown(ctx context.Context) error {
	// nothing to do here
	// Shard::shutdown will take care of handling store's buckets
	return nil
}

// SwitchCommitLogs is a no-op: the flat index keeps no commit logs of its own.
func (index *flat) SwitchCommitLogs(context.Context) error {
	return nil
}

// ListFiles reports no files of its own; the store's buckets are covered by
// the shard's backup listing.
func (index *flat) ListFiles(ctx context.Context, basePath string) ([]string, error) {
	// nothing to do here
	// Shard::ListBackupFiles will take care of handling store's buckets
	return []string{}, nil
}
func (i *flat) ValidateBeforeInsert(vector []float32) error { | |
return nil | |
} | |
// PostStartup warms the BQ cache (when enabled) by scanning the compressed
// bucket once and preloading every stored quantized vector.
func (index *flat) PostStartup() {
	if !index.isBQCached() {
		return
	}
	cursor := index.store.Bucket(helpers.VectorsCompressedBucketLSM).Cursor()
	defer cursor.Close()
	for key, v := cursor.First(); key != nil; key, v = cursor.Next() {
		id := binary.BigEndian.Uint64(key)
		index.bqCache.Preload(id, uint64SliceFromByteSlice(v, make([]uint64, len(v)/8)))
	}
}
func (index *flat) Dump(labels ...string) { | |
if len(labels) > 0 { | |
fmt.Printf("--------------------------------------------------\n") | |
fmt.Printf("-- %s\n", strings.Join(labels, ", ")) | |
} | |
fmt.Printf("--------------------------------------------------\n") | |
fmt.Printf("ID: %s\n", index.id) | |
fmt.Printf("--------------------------------------------------\n") | |
} | |
// DistanceBetweenVectors scores two raw vectors with the configured
// distancer. The bool mirrors SingleDist's second return value.
// NOTE(review): the bool's exact semantics are defined by the provider —
// confirm against distancer.Provider before relying on it.
func (index *flat) DistanceBetweenVectors(x, y []float32) (float32, bool, error) {
	return index.distancerProvider.SingleDist(x, y)
}

// ContainsNode always reports true: the flat index keeps no membership
// structure of its own; presence is determined by the underlying buckets.
func (index *flat) ContainsNode(id uint64) bool {
	return true
}

// DistancerProvider exposes the configured distance provider.
func (index *flat) DistancerProvider() distancer.Provider {
	return index.distancerProvider
}
// newSearchByDistParams builds the pagination parameters used by
// SearchByVectorDistance: offset 0, the default initial page size, and the
// given overall cap.
func newSearchByDistParams(maxLimit int64) *common.SearchByDistParams {
	initialOffset := 0
	initialLimit := common.DefaultSearchByDistInitialLimit
	return common.NewSearchByDistParams(initialOffset, initialLimit, initialOffset+initialLimit, maxLimit)
}
// immutableParameter pairs a config-field accessor with a human-readable
// name used in error messages when an immutable setting is changed.
type immutableParameter struct {
	accessor func(c flatent.UserConfig) interface{}
	name     string
}
func validateImmutableField(u immutableParameter, | |
previous, next flatent.UserConfig, | |
) error { | |
oldField := u.accessor(previous) | |
newField := u.accessor(next) | |
if oldField != newField { | |
return errors.Errorf("%s is immutable: attempted change from \"%v\" to \"%v\"", | |
u.name, oldField, newField) | |
} | |
return nil | |
} | |
func ValidateUserConfigUpdate(initial, updated schema.VectorIndexConfig) error { | |
initialParsed, ok := initial.(flatent.UserConfig) | |
if !ok { | |
return errors.Errorf("initial is not UserConfig, but %T", initial) | |
} | |
updatedParsed, ok := updated.(flatent.UserConfig) | |
if !ok { | |
return errors.Errorf("updated is not UserConfig, but %T", updated) | |
} | |
immutableFields := []immutableParameter{ | |
{ | |
name: "distance", | |
accessor: func(c flatent.UserConfig) interface{} { return c.Distance }, | |
}, | |
{ | |
name: "bq.cache", | |
accessor: func(c flatent.UserConfig) interface{} { return c.BQ.Cache }, | |
}, | |
{ | |
name: "pq.cache", | |
accessor: func(c flatent.UserConfig) interface{} { return c.PQ.Cache }, | |
}, | |
{ | |
name: "pq", | |
accessor: func(c flatent.UserConfig) interface{} { return c.PQ.Enabled }, | |
}, | |
{ | |
name: "bq", | |
accessor: func(c flatent.UserConfig) interface{} { return c.BQ.Enabled }, | |
}, | |
} | |
for _, u := range immutableFields { | |
if err := validateImmutableField(u, initialParsed, updatedParsed); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |