// KevinStephenson
// Adding in weaviate code
// b110593
// raw
// history blame
// 20.3 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package flat
import (
"context"
"encoding/binary"
"fmt"
"io"
"math"
"strings"
"sync"
"sync/atomic"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/adapters/repos/db/helpers"
"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
"github.com/weaviate/weaviate/adapters/repos/db/priorityqueue"
"github.com/weaviate/weaviate/adapters/repos/db/vector/cache"
"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
"github.com/weaviate/weaviate/entities/schema"
flatent "github.com/weaviate/weaviate/entities/vectorindex/flat"
"github.com/weaviate/weaviate/usecases/floatcomp"
)
// Compression modes derived from the user config. The flat index runs either
// uncompressed or with binary quantization; PQ is recognized in the config
// but currently has no dedicated search path (see SearchByVector).
const (
	compressionBQ = "bq"
	compressionPQ = "pq"
	compressionNone = "none"
)
// flat is a brute-force (exhaustive scan) vector index persisted in an LSM
// store, optionally using binary quantization (BQ) with an in-memory cache
// of the compressed vectors.
type flat struct {
	sync.Mutex
	id string
	// dims is captured from the first inserted vector; written with
	// atomic.StoreInt32 (see Add)
	dims int32
	store *lsmkv.Store
	logger logrus.FieldLogger
	distancerProvider distancer.Provider
	// trackDimensionsOnce guards the one-time capture of dims and the BQ
	// quantizer setup
	trackDimensionsOnce sync.Once
	// rescore is the number of compressed candidates fetched before exact
	// re-ranking; read/written atomically (see searchTimeRescore,
	// UpdateUserConfig)
	rescore int64
	bq compressionhelpers.BinaryQuantizer
	pqResults *common.PqMaxPool
	pool *pools
	// compression is one of compressionBQ/compressionPQ/compressionNone
	compression string
	// bqCache is non-nil only when BQ with caching is enabled (see New)
	bqCache cache.Cache[uint64]
}
// distanceCalc computes the distance between a pre-bound query vector and a
// candidate vector supplied in its raw byte encoding.
type distanceCalc func(vecAsBytes []byte) (float32, error)
// New creates a flat vector index backed by the given LSM store. The user
// config determines the compression mode and, for BQ, whether an in-memory
// cache of compressed vectors is kept.
func New(cfg Config, uc flatent.UserConfig, store *lsmkv.Store) (*flat, error) {
	if err := cfg.Validate(); err != nil {
		return nil, errors.Wrap(err, "invalid config")
	}
	logger := cfg.Logger
	if logger == nil {
		// discard all output when no logger was configured
		l := logrus.New()
		l.Out = io.Discard
		logger = l
	}
	index := &flat{
		id:                cfg.ID,
		logger:            logger,
		distancerProvider: cfg.DistanceProvider,
		rescore:           extractCompressionRescore(uc),
		pqResults:         common.NewPqMaxPool(100),
		compression:       extractCompression(uc),
		pool:              newPools(),
		store:             store,
	}
	// fix: initBuckets returns an error that was previously discarded; a
	// failed bucket creation must abort index construction
	if err := index.initBuckets(context.Background()); err != nil {
		return nil, errors.Wrap(err, "init index buckets")
	}
	if uc.BQ.Enabled && uc.BQ.Cache {
		index.bqCache = cache.NewShardedUInt64LockCache(index.getBQVector, uc.VectorCacheMaxObjects, cfg.Logger, 0)
	}
	return index, nil
}
// getBQVector loads the binary-quantized vector for id directly from the
// compressed LSM bucket. It serves as the fill function for the BQ cache.
func (flat *flat) getBQVector(ctx context.Context, id uint64) ([]uint64, error) {
	keyBuf := flat.pool.byteSlicePool.Get(8)
	defer flat.pool.byteSlicePool.Put(keyBuf)
	binary.BigEndian.PutUint64(keyBuf.slice, id)
	compressed, err := flat.store.Bucket(helpers.VectorsCompressedBucketLSM).Get(keyBuf.slice)
	if err != nil {
		return nil, err
	}
	out := make([]uint64, len(compressed)/8)
	return uint64SliceFromByteSlice(compressed, out), nil
}
// extractCompression maps the user config onto one of the compression mode
// constants. Enabling both BQ and PQ at once is treated as no compression.
func extractCompression(uc flatent.UserConfig) string {
	switch {
	case uc.BQ.Enabled && uc.PQ.Enabled:
		// both quantizers at once is not a supported combination
		return compressionNone
	case uc.BQ.Enabled:
		return compressionBQ
	case uc.PQ.Enabled:
		return compressionPQ
	default:
		return compressionNone
	}
}
// extractCompressionRescore returns the rescore limit matching the active
// compression mode, or 0 when no compression is configured.
func extractCompressionRescore(uc flatent.UserConfig) int64 {
	switch extractCompression(uc) {
	case compressionPQ:
		return int64(uc.PQ.RescoreLimit)
	case compressionBQ:
		return int64(uc.BQ.RescoreLimit)
	}
	return 0
}
// storeCompressedVector persists the BQ-encoded vector bytes under id in the
// compressed vectors bucket.
func (index *flat) storeCompressedVector(id uint64, vector []byte) {
	index.storeGenericVector(id, vector, helpers.VectorsCompressedBucketLSM)
}
// storeVector persists the raw (uncompressed) vector bytes under id in the
// main vectors bucket.
func (index *flat) storeVector(id uint64, vector []byte) {
	index.storeGenericVector(id, vector, helpers.VectorsBucketLSM)
}
// storeGenericVector writes vector into the given bucket under the
// big-endian encoding of id (big-endian so LSM key order equals id order,
// which the cursor-based searches rely on).
func (index *flat) storeGenericVector(id uint64, vector []byte, bucket string) {
	idBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(idBytes, id)
	// fix: the Put error was silently discarded; surface it in the logs so
	// store failures are not invisible (the signature is kept void to stay
	// compatible with existing callers)
	if err := index.store.Bucket(bucket).Put(idBytes, vector); err != nil {
		index.logger.WithField("action", "flat_store_vector").
			WithError(err).
			Errorf("failed to store vector in bucket %q", bucket)
	}
}
// isBQ reports whether binary quantization is the active compression mode.
func (index *flat) isBQ() bool {
	return index.compression == compressionBQ
}
// isBQCached reports whether the in-memory BQ cache is enabled (only set in
// New when BQ.Enabled && BQ.Cache).
func (index *flat) isBQCached() bool {
	return index.bqCache != nil
}
// Compressed reports whether any compression mode (BQ or PQ) is configured.
func (index *flat) Compressed() bool {
	return index.compression != compressionNone
}
// initBuckets creates or loads the LSM bucket for raw vectors and, when BQ
// is enabled, the bucket for compressed vectors. Bloom filters and net-count
// tracking are disabled as vectors are only ever read by exact key.
func (index *flat) initBuckets(ctx context.Context) error {
	if err := index.store.CreateOrLoadBucket(ctx, helpers.VectorsBucketLSM,
		lsmkv.WithForceCompation(true),
		lsmkv.WithUseBloomFilter(false),
		lsmkv.WithCalcCountNetAdditions(false),
	); err != nil {
		// fix: error strings are lowercase per Go convention (they get
		// wrapped/chained by callers)
		return fmt.Errorf("create or load flat vectors bucket: %w", err)
	}
	if index.isBQ() {
		if err := index.store.CreateOrLoadBucket(ctx, helpers.VectorsCompressedBucketLSM,
			lsmkv.WithForceCompation(true),
			lsmkv.WithUseBloomFilter(false),
			lsmkv.WithCalcCountNetAdditions(false),
		); err != nil {
			return fmt.Errorf("create or load flat compressed vectors bucket: %w", err)
		}
	}
	return nil
}
// AddBatch inserts the given vectors under their corresponding ids, checking
// for context cancellation before each individual insert. ids and vectors
// must be non-empty and of equal length.
func (index *flat) AddBatch(ctx context.Context, ids []uint64, vectors [][]float32) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	if len(ids) != len(vectors) {
		// fix: grammar ("sizes does not match")
		return errors.Errorf("ids and vectors sizes do not match")
	}
	if len(ids) == 0 {
		// fix: message referenced "insertBatch", which is not this method's name
		return errors.Errorf("AddBatch called with empty lists")
	}
	for i := range ids {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := index.Add(ids[i], vectors[i]); err != nil {
			return err
		}
	}
	return nil
}
// byteSliceFromUint64Slice serializes vector into slice as consecutive
// 8-byte little-endian words. slice must hold at least len(vector)*8 bytes;
// it is returned for convenience.
func byteSliceFromUint64Slice(vector []uint64, slice []byte) []byte {
	offset := 0
	for _, word := range vector {
		binary.LittleEndian.PutUint64(slice[offset:], word)
		offset += 8
	}
	return slice
}
// byteSliceFromFloat32Slice serializes vector into slice as consecutive
// 4-byte little-endian IEEE-754 values. slice must hold at least
// len(vector)*4 bytes; it is returned for convenience.
func byteSliceFromFloat32Slice(vector []float32, slice []byte) []byte {
	offset := 0
	for _, val := range vector {
		binary.LittleEndian.PutUint32(slice[offset:], math.Float32bits(val))
		offset += 4
	}
	return slice
}
// uint64SliceFromByteSlice decodes little-endian 8-byte words from vector
// into slice. The number of decoded words is len(slice); vector must hold at
// least len(slice)*8 bytes. slice is returned for convenience.
func uint64SliceFromByteSlice(vector []byte, slice []uint64) []uint64 {
	for i := 0; i < len(slice); i++ {
		slice[i] = binary.LittleEndian.Uint64(vector[8*i:])
	}
	return slice
}
// float32SliceFromByteSlice decodes little-endian 4-byte IEEE-754 values
// from vector into slice. The number of decoded values is len(slice); vector
// must hold at least len(slice)*4 bytes. slice is returned for convenience.
func float32SliceFromByteSlice(vector []byte, slice []float32) []float32 {
	for i := 0; i < len(slice); i++ {
		bits := binary.LittleEndian.Uint32(vector[4*i:])
		slice[i] = math.Float32frombits(bits)
	}
	return slice
}
// Add inserts a single vector under id. The first insert fixes the index
// dimensionality (and initializes the BQ quantizer when enabled); later
// inserts with a different length are rejected. Vectors are normalized for
// cosine-dot distancers before storage.
func (index *flat) Add(id uint64, vector []float32) error {
	index.trackDimensionsOnce.Do(func() {
		atomic.StoreInt32(&index.dims, int32(len(vector)))
		if index.isBQ() {
			index.bq = compressionhelpers.NewBinaryQuantizer(nil)
		}
	})
	// fix: dims is written with atomic.StoreInt32 above but was read
	// non-atomically here, which is a data race when Add runs concurrently
	// with the first insert
	if len(vector) != int(atomic.LoadInt32(&index.dims)) {
		return errors.Errorf("insert called with a vector of the wrong size")
	}
	vector = index.normalized(vector)
	slice := make([]byte, len(vector)*4)
	index.storeVector(id, byteSliceFromFloat32Slice(vector, slice))
	if index.isBQ() {
		vectorBQ := index.bq.Encode(vector)
		if index.isBQCached() {
			index.bqCache.Grow(id)
			index.bqCache.Preload(id, vectorBQ)
		}
		// fresh slice: the compressed encoding has a different byte length
		slice = make([]byte, len(vectorBQ)*8)
		index.storeCompressedVector(id, byteSliceFromUint64Slice(vectorBQ, slice))
	}
	return nil
}
// Delete removes each id's vector from the main bucket, from the compressed
// bucket when BQ is enabled, and from the BQ cache when caching is enabled.
func (index *flat) Delete(ids ...uint64) error {
	for _, id := range ids {
		if index.isBQCached() {
			index.bqCache.Delete(context.Background(), id)
		}
		key := make([]byte, 8)
		binary.BigEndian.PutUint64(key, id)
		if err := index.store.Bucket(helpers.VectorsBucketLSM).Delete(key); err != nil {
			return err
		}
		if !index.isBQ() {
			continue
		}
		if err := index.store.Bucket(helpers.VectorsCompressedBucketLSM).Delete(key); err != nil {
			return err
		}
	}
	return nil
}
// searchTimeRescore returns how many candidates to fetch before exact
// rescoring: the configured rescore limit when it exceeds k, otherwise k.
func (index *flat) searchTimeRescore(k int) int {
	// atomic load: the limit may be changed concurrently via
	// UpdateUserConfig, and taking a lock here would add overhead to every
	// single search
	rescore := int(atomic.LoadInt64(&index.rescore))
	if rescore > k {
		return rescore
	}
	return k
}
// SearchByVector returns the k nearest ids and distances for the query
// vector. BQ dispatches to the compressed search; PQ currently has no
// dedicated path and falls back to the uncompressed brute-force scan.
func (index *flat) SearchByVector(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	if index.compression == compressionBQ {
		return index.searchByVectorBQ(vector, k, allow)
	}
	return index.searchByVector(vector, k, allow)
}
// searchByVector performs an exhaustive scan over the uncompressed vectors
// bucket, returning the k closest ids with their distances.
func (index *flat) searchByVector(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	resultHeap := index.pqResults.GetMax(k)
	defer index.pqResults.Put(resultHeap)
	query := index.normalized(vector)
	err := index.findTopVectors(resultHeap, allow, k,
		index.store.Bucket(helpers.VectorsBucketLSM).Cursor,
		index.createDistanceCalc(query))
	if err != nil {
		return nil, nil, err
	}
	ids, dists := index.extractHeap(resultHeap)
	return ids, dists, nil
}
// createDistanceCalc returns a closure that computes the distance between
// the bound query vector and a candidate given in raw byte encoding, using a
// pooled scratch slice for the decoded candidate.
func (index *flat) createDistanceCalc(vector []float32) distanceCalc {
	return func(vecAsBytes []byte) (float32, error) {
		scratch := index.pool.float32SlicePool.Get(len(vecAsBytes) / 4)
		defer index.pool.float32SlicePool.Put(scratch)
		candidate := float32SliceFromByteSlice(vecAsBytes, scratch.slice)
		dist, _, err := index.distancerProvider.SingleDist(vector, candidate)
		return dist, err
	}
}
// searchByVectorBQ answers a kNN query using binary-quantized vectors: it
// first gathers up to `rescore` candidates by compressed distance (from the
// cache when available, otherwise by scanning the compressed bucket), then
// re-ranks those candidates with exact distances on the original vectors and
// keeps the best k.
func (index *flat) searchByVectorBQ(vector []float32, k int, allow helpers.AllowList) ([]uint64, []float32, error) {
	rescore := index.searchTimeRescore(k)
	heap := index.pqResults.GetMax(rescore)
	defer index.pqResults.Put(heap)
	vector = index.normalized(vector)
	vectorBQ := index.bq.Encode(vector)
	if index.isBQCached() {
		if err := index.findTopVectorsCached(heap, allow, rescore, vectorBQ); err != nil {
			return nil, nil, err
		}
	} else {
		if err := index.findTopVectors(heap, allow, rescore,
			index.store.Bucket(helpers.VectorsCompressedBucketLSM).Cursor,
			index.createDistanceCalcBQ(vectorBQ),
		); err != nil {
			return nil, nil, err
		}
	}
	// rescore phase: exact distances on the uncompressed vectors
	distanceCalc := index.createDistanceCalc(vector)
	idsSlice := index.pool.uint64SlicePool.Get(heap.Len())
	defer index.pool.uint64SlicePool.Put(idsSlice)
	// drain the heap into a scratch slice first, since the same heap is
	// reused (capped at k) for the re-ranked results below
	for i := range idsSlice.slice {
		idsSlice.slice[i] = heap.Pop().ID
	}
	for _, id := range idsSlice.slice {
		candidateAsBytes, err := index.vectorById(id)
		if err != nil {
			return nil, nil, err
		}
		distance, err := distanceCalc(candidateAsBytes)
		if err != nil {
			return nil, nil, err
		}
		index.insertToHeap(heap, k, id, distance)
	}
	ids, dists := index.extractHeap(heap)
	return ids, dists, nil
}
// createDistanceCalcBQ returns a closure that computes the compressed
// distance between the bound BQ-encoded query and a candidate given as raw
// bytes, decoding into a pooled scratch slice.
func (index *flat) createDistanceCalcBQ(vectorBQ []uint64) distanceCalc {
	return func(vecAsBytes []byte) (float32, error) {
		scratch := index.pool.uint64SlicePool.Get(len(vecAsBytes) / 8)
		defer index.pool.uint64SlicePool.Put(scratch)
		candidate := uint64SliceFromByteSlice(vecAsBytes, scratch.slice)
		return index.bq.DistanceBetweenCompressedVectors(candidate, vectorBQ)
	}
}
// vectorById fetches the raw (uncompressed) vector bytes stored for id.
func (index *flat) vectorById(id uint64) ([]byte, error) {
	keyBuf := index.pool.byteSlicePool.Get(8)
	defer index.pool.byteSlicePool.Put(keyBuf)
	binary.BigEndian.PutUint64(keyBuf.slice, id)
	return index.store.Bucket(helpers.VectorsBucketLSM).Get(keyBuf.slice)
}
// findTopVectors populates the heap with the `limit` smallest distances (and
// their ids) found by scanning the bucket cursor, using distanceCalc to
// score each candidate. When an allow list is given, the scan is bounded to
// [allow.Min(), allow.Max()] — keys are stored big-endian, so cursor order
// equals id order.
func (index *flat) findTopVectors(heap *priorityqueue.Queue[any],
	allow helpers.AllowList, limit int, cursorFn func() *lsmkv.CursorReplace,
	distanceCalc distanceCalc,
) error {
	var key []byte
	var v []byte
	allowMax := uint64(0)
	cursor := cursorFn()
	defer cursor.Close()
	if allow != nil {
		// nothing allowed, skip search
		if allow.IsEmpty() {
			return nil
		}
		allowMax = allow.Max()
		idSlice := index.pool.byteSlicePool.Get(8)
		binary.BigEndian.PutUint64(idSlice.slice, allow.Min())
		key, v = cursor.Seek(idSlice.slice)
		index.pool.byteSlicePool.Put(idSlice)
	} else {
		key, v = cursor.First()
	}
	for ; key != nil; key, v = cursor.Next() {
		id := binary.BigEndian.Uint64(key)
		// fix: the original loop condition tested the id decoded in the
		// *previous* iteration, so it always processed one extra element
		// past allowMax (harmless only because Contains filtered it);
		// decode first, then stop as soon as we are past the allowed range
		if allow != nil && id > allowMax {
			break
		}
		if allow == nil || allow.Contains(id) {
			distance, err := distanceCalc(v)
			if err != nil {
				return err
			}
			index.insertToHeap(heap, limit, id, distance)
		}
	}
	return nil
}
// findTopVectorsCached populates the heap with the `limit` smallest
// compressed distances by iterating ids through the BQ cache instead of the
// LSM store.
//
// NOTE(review): the loop walks ids from 0 (or allow.Min()) up to
// index.bqCache.Len(), which assumes ids are densely assigned starting at 0;
// ids at or beyond the cache length would be skipped — confirm against how
// ids are allocated by the caller.
func (index *flat) findTopVectorsCached(heap *priorityqueue.Queue[any],
	allow helpers.AllowList, limit int, vectorBQ []uint64,
) error {
	var id uint64
	allowMax := uint64(0)
	if allow != nil {
		// nothing allowed, skip search
		if allow.IsEmpty() {
			return nil
		}
		allowMax = allow.Max()
		id = allow.Min()
	} else {
		id = 0
	}
	all := index.bqCache.Len()
	// ids are visited in ascending order, so once past the allow list's max
	// id no further candidate can match
	for ; id < uint64(all) && (allow == nil || id <= allowMax); id++ {
		if allow == nil || allow.Contains(id) {
			vec, err := index.bqCache.Get(context.Background(), id)
			if err != nil {
				return err
			}
			// an empty slice presumably means no vector is stored for this
			// id (e.g. a gap or deleted entry) — skip it
			if len(vec) == 0 {
				continue
			}
			distance, err := index.bq.DistanceBetweenCompressedVectors(vec, vectorBQ)
			if err != nil {
				return err
			}
			index.insertToHeap(heap, limit, id, distance)
		}
	}
	return nil
}
// insertToHeap keeps at most limit entries in the max-heap, retaining the
// candidates with the smallest distances seen so far.
func (index *flat) insertToHeap(heap *priorityqueue.Queue[any],
	limit int, id uint64, distance float32,
) {
	if heap.Len() >= limit {
		// heap full: only admit the candidate if it beats the current worst
		if heap.Top().Dist <= distance {
			return
		}
		heap.Pop()
	}
	heap.Insert(id, distance)
}
// extractHeap drains the max-heap into id and distance slices ordered by
// ascending distance (the heap pops largest-first, so slices are filled from
// the back).
func (index *flat) extractHeap(heap *priorityqueue.Queue[any],
) ([]uint64, []float32) {
	// fix: the local was named `len`, shadowing the builtin
	size := heap.Len()
	ids := make([]uint64, size)
	dists := make([]float32, size)
	for i := size - 1; i >= 0; i-- {
		item := heap.Pop()
		ids[i] = item.ID
		dists[i] = item.Dist
	}
	return ids, dists
}
// normalized returns the vector normalized when the configured distancer is
// cosine-dot (dot product equals cosine similarity only on unit vectors);
// for all other distancers the vector is returned unchanged.
func (index *flat) normalized(vector []float32) []float32 {
	if index.distancerProvider.Type() != "cosine-dot" {
		return vector
	}
	return distancer.Normalize(vector)
}
// SearchByVectorDistance returns all ids whose distance to the query vector
// is at most targetDistance (within a small epsilon), up to maxLimit
// results. It repeatedly runs SearchByVector with a growing limit until a
// result falls outside the target distance, the index is exhausted, or the
// maximum search limit is reached.
func (index *flat) SearchByVectorDistance(vector []float32, targetDistance float32, maxLimit int64, allow helpers.AllowList) ([]uint64, []float32, error) {
	var (
		searchParams = newSearchByDistParams(maxLimit)
		resultIDs    []uint64
		resultDist   []float32
	)
	recursiveSearch := func() (bool, error) {
		totalLimit := searchParams.TotalLimit()
		ids, dist, err := index.SearchByVector(vector, totalLimit, allow)
		if err != nil {
			return false, errors.Wrap(err, "vector search")
		}
		// fewer results than the limit means the index is exhausted
		shouldContinue := !(len(ids) < totalLimit)
		// ensures the indexes aren't out of range
		offsetCap := searchParams.OffsetCapacity(ids)
		totalLimitCap := searchParams.TotalLimitCapacity(ids)
		if offsetCap == totalLimitCap {
			return false, nil
		}
		ids, dist = ids[offsetCap:totalLimitCap], dist[offsetCap:totalLimitCap]
		for i := range ids {
			// fix: renamed from `aboveThresh`, which named the opposite of
			// the condition it holds (true when within the target distance)
			if withinThresh := dist[i] <= targetDistance; withinThresh ||
				floatcomp.InDelta(float64(dist[i]), float64(targetDistance), 1e-6) {
				resultIDs = append(resultIDs, ids[i])
				resultDist = append(resultDist, dist[i])
			} else {
				// results arrive sorted by distance: once one exceeds the
				// threshold, all following ones do too
				shouldContinue = false
				break
			}
		}
		return shouldContinue, nil
	}
	// fix: the original loop invoked recursiveSearch only once (in the for
	// init) and had no post statement, so after the first page it spun on
	// Iterate() until MaxLimitReached without ever fetching more results;
	// re-invoke the search on every pass
	for {
		shouldContinue, err := recursiveSearch()
		if err != nil {
			return nil, nil, err
		}
		if !shouldContinue {
			break
		}
		searchParams.Iterate()
		if searchParams.MaxLimitReached() {
			index.logger.
				WithField("action", "unlimited_vector_search").
				Warnf("maximum search limit of %d results has been reached",
					searchParams.MaximumSearchLimit())
			break
		}
	}
	return resultIDs, resultDist, nil
}
// UpdateUserConfig applies the updated rescore limit. callback is always
// invoked, even when the supplied config has an unexpected type.
func (index *flat) UpdateUserConfig(updated schema.VectorIndexConfig, callback func()) error {
	defer callback()
	parsed, ok := updated.(flatent.UserConfig)
	if !ok {
		return errors.Errorf("config is not UserConfig, but %T", updated)
	}
	// atomic store: rescore is read lock-free on every user-facing search,
	// which can be highly concurrent
	atomic.StoreInt64(&index.rescore, extractCompressionRescore(parsed))
	return nil
}
// Drop is a no-op for the flat index.
func (index *flat) Drop(ctx context.Context) error {
	// nothing to do here
	// Shard::drop will take care of handling store's buckets
	return nil
}
// Flush is a no-op for the flat index.
func (index *flat) Flush() error {
	// nothing to do here
	// Shard will take care of handling store's buckets
	return nil
}
// Shutdown is a no-op for the flat index.
func (index *flat) Shutdown(ctx context.Context) error {
	// nothing to do here
	// Shard::shutdown will take care of handling store's buckets
	return nil
}
// SwitchCommitLogs is a no-op: the flat index has no commit logs of its own;
// persistence goes through the LSM store.
func (index *flat) SwitchCommitLogs(context.Context) error {
	return nil
}
// ListFiles returns an empty list: the flat index owns no files of its own.
func (index *flat) ListFiles(ctx context.Context, basePath string) ([]string, error) {
	// nothing to do here
	// Shard::ListBackupFiles will take care of handling store's buckets
	return []string{}, nil
}
// ValidateBeforeInsert is a no-op for the flat index; dimension checking
// happens in Add. (Receiver renamed from `i` to `index` for consistency
// with every other method on this type.)
func (index *flat) ValidateBeforeInsert(vector []float32) error {
	return nil
}
// PostStartup warms the BQ cache by scanning the whole compressed bucket.
// No-op when BQ caching is disabled.
func (index *flat) PostStartup() {
	if !index.isBQCached() {
		return
	}
	cursor := index.store.Bucket(helpers.VectorsCompressedBucketLSM).Cursor()
	defer cursor.Close()
	for key, val := cursor.First(); key != nil; key, val = cursor.Next() {
		vec := uint64SliceFromByteSlice(val, make([]uint64, len(val)/8))
		index.bqCache.Preload(binary.BigEndian.Uint64(key), vec)
	}
}
// Dump prints a short human-readable description of the index to stdout,
// optionally prefixed with the given labels. Debugging aid only.
func (index *flat) Dump(labels ...string) {
	separator := "--------------------------------------------------\n"
	if len(labels) > 0 {
		fmt.Printf(separator)
		fmt.Printf("-- %s\n", strings.Join(labels, ", "))
	}
	fmt.Printf(separator)
	fmt.Printf("ID: %s\n", index.id)
	fmt.Printf(separator)
}
// DistanceBetweenVectors computes the distance between x and y using the
// configured distance provider.
func (index *flat) DistanceBetweenVectors(x, y []float32) (float32, bool, error) {
	return index.distancerProvider.SingleDist(x, y)
}
// ContainsNode always reports true — the flat index keeps no per-node
// structure it could consult here.
// NOTE(review): presumably callers rely on the flat index being treated as
// containing every id; confirm against ContainsNode's call sites.
func (index *flat) ContainsNode(id uint64) bool {
	return true
}
// DistancerProvider exposes the distance provider this index was built with.
func (index *flat) DistancerProvider() distancer.Provider {
	return index.distancerProvider
}
// newSearchByDistParams builds the initial paging parameters used by
// SearchByVectorDistance: offset 0 and the package's default initial limit.
func newSearchByDistParams(maxLimit int64) *common.SearchByDistParams {
	const offset = 0
	limit := common.DefaultSearchByDistInitialLimit
	return common.NewSearchByDistParams(offset, limit, offset+limit, maxLimit)
}
// immutableParameter pairs a human-readable config field name with an
// accessor used to compare that field across config updates.
type immutableParameter struct {
	accessor func(c flatent.UserConfig) interface{}
	name string
}
// validateImmutableField returns an error when the accessor yields different
// values for the previous and next configs, i.e. an immutable setting was
// changed.
func validateImmutableField(u immutableParameter,
	previous, next flatent.UserConfig,
) error {
	before := u.accessor(previous)
	after := u.accessor(next)
	if before == after {
		return nil
	}
	return errors.Errorf("%s is immutable: attempted change from \"%v\" to \"%v\"",
		u.name, before, after)
}
// ValidateUserConfigUpdate verifies that no immutable setting (distance
// metric, compression flags, cache flags) differs between the initial and
// updated configs.
func ValidateUserConfigUpdate(initial, updated schema.VectorIndexConfig) error {
	initialParsed, ok := initial.(flatent.UserConfig)
	if !ok {
		return errors.Errorf("initial is not UserConfig, but %T", initial)
	}
	updatedParsed, ok := updated.(flatent.UserConfig)
	if !ok {
		return errors.Errorf("updated is not UserConfig, but %T", updated)
	}
	immutableFields := []immutableParameter{
		{name: "distance", accessor: func(c flatent.UserConfig) interface{} { return c.Distance }},
		{name: "bq.cache", accessor: func(c flatent.UserConfig) interface{} { return c.BQ.Cache }},
		{name: "pq.cache", accessor: func(c flatent.UserConfig) interface{} { return c.PQ.Cache }},
		{name: "pq", accessor: func(c flatent.UserConfig) interface{} { return c.PQ.Enabled }},
		{name: "bq", accessor: func(c flatent.UserConfig) interface{} { return c.BQ.Enabled }},
	}
	for _, field := range immutableFields {
		if err := validateImmutableField(field, initialParsed, updatedParsed); err != nil {
			return err
		}
	}
	return nil
}