KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.02 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package aggregator
import (
"context"
"fmt"
"github.com/weaviate/weaviate/adapters/repos/db/helpers"
"github.com/weaviate/weaviate/adapters/repos/db/inverted"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/storobj"
)
// vectorSearch picks the vector search strategy for the aggregation.
//
// When params.ObjectLimit is set, the search is delegated to searchByVector
// with that limit. Otherwise it is delegated to searchByVectorDistance. The
// doc IDs, distances and error of the chosen search are returned unchanged.
func (a *Aggregator) vectorSearch(allow helpers.AllowList, vec []float32) ([]uint64, []float32, error) {
	limit := a.params.ObjectLimit
	if limit == nil {
		// No explicit object limit: fall back to the distance-bounded search.
		return a.searchByVectorDistance(vec, allow)
	}
	return a.searchByVector(vec, limit, allow)
}
// searchByVector runs a limit-bounded vector search against the vector index
// and, if params.Certainty is positive, additionally trims the results to
// those within the corresponding distance.
//
// limit must be non-nil; it is dereferenced unconditionally. ids is passed to
// the index as the allow list.
//
// The returned doc IDs and distances always have matching lengths, so that
// element i of one corresponds to element i of the other. On error, both
// result slices are nil.
func (a *Aggregator) searchByVector(searchVector []float32, limit *int, ids helpers.AllowList) ([]uint64, []float32, error) {
	idsFound, dists, err := a.vectorIndex.SearchByVector(searchVector, *limit, ids)
	if err != nil {
		// Results are meaningless alongside a non-nil error; do not hand
		// back a partially populated slice.
		return nil, nil, err
	}

	if a.params.Certainty <= 0 {
		return idsFound, dists, nil
	}

	// Same certainty-to-distance conversion as in searchByVectorDistance.
	targetDist := float32(1-a.params.Certainty) * 2

	// Count the leading results that are within the target distance. The
	// scan stops at the first result that exceeds it.
	// NOTE(review): this assumes the index returns distances in ascending
	// order - confirm against the vector index implementation.
	cut := 0
	for _, dist := range dists {
		if dist > targetDist {
			break
		}
		cut++
	}

	// Guard against an index that returns fewer IDs than distances, which
	// would otherwise make the slice expression below panic.
	if cut > len(idsFound) {
		cut = len(idsFound)
	}

	// Truncate both slices so IDs and distances stay aligned; previously
	// only the IDs were cut, leaving trailing distances without an ID.
	return idsFound[:cut], dists[:cut], nil
}
// searchByVectorDistance runs a distance-bounded vector search against the
// vector index, using a maximum distance derived from params.Certainty.
//
// It returns an error if params.Certainty is not positive, since in that case
// there is nothing to bound the search with. Errors from the index are wrapped.
// ids is passed to the index as the allow list.
func (a *Aggregator) searchByVectorDistance(searchVector []float32, ids helpers.AllowList) ([]uint64, []float32, error) {
	certainty := a.params.Certainty
	if certainty <= 0 {
		return nil, nil, fmt.Errorf("must provide certainty or objectLimit with vector search")
	}

	// Convert the requested certainty into the maximum allowed distance.
	maxDist := float32(1-certainty) * 2

	// NOTE(review): -1 is presumably "no upper bound on result count" -
	// confirm against the vector index's SearchByVectorDistance contract.
	matches, distances, err := a.vectorIndex.SearchByVectorDistance(searchVector, maxDist, -1, ids)
	if err != nil {
		return nil, nil, fmt.Errorf("aggregate search by vector: %w", err)
	}

	return matches, distances, nil
}
// objectVectorSearch performs a vector search and resolves the resulting doc
// IDs into stored objects.
//
// It calls vectorSearch with the given vector and allow list, then loads the
// matching objects from the objects bucket of the store. The distances
// reported by the search are returned alongside the objects. An error from
// the search is returned as-is; an error from the object lookup is wrapped.
func (a *Aggregator) objectVectorSearch(searchVector []float32,
	allowList helpers.AllowList,
) ([]*storobj.Object, []float32, error) {
	docIDs, distances, err := a.vectorSearch(allowList, searchVector)
	if err != nil {
		return nil, nil, err
	}

	// Resolve doc IDs to objects; no additional properties are requested.
	objects, err := storobj.ObjectsByDocID(
		a.store.Bucket(helpers.ObjectsBucketLSM),
		docIDs,
		additional.Properties{},
	)
	if err != nil {
		return nil, nil, fmt.Errorf("get objects by doc id: %w", err)
	}

	return objects, distances, nil
}
// buildAllowList resolves params.Filters into an allow list of doc IDs.
//
// If no filters are set, it returns a nil allow list and a nil error.
// Otherwise it builds an inverted-index searcher from the aggregator's
// dependencies and asks it for the doc IDs matching the filters on
// params.ClassName. Errors from the searcher are wrapped.
func (a *Aggregator) buildAllowList(ctx context.Context) (helpers.AllowList, error) {
	filters := a.params.Filters
	if filters == nil {
		// Nothing to filter on: callers receive a nil allow list.
		return nil, nil
	}

	sch := a.getSchema.GetSchemaSkipAuth()
	searcher := inverted.NewSearcher(a.logger, a.store, sch, nil,
		a.classSearcher, a.stopwords, a.shardVersion, a.isFallbackToSearchable,
		a.tenant, a.nestedCrossRefLimit)

	var allow helpers.AllowList
	var err error
	allow, err = searcher.DocIDs(ctx, filters, additional.Properties{}, a.params.ClassName)
	if err != nil {
		return nil, fmt.Errorf("retrieve doc IDs from searcher: %w", err)
	}

	return allow, nil
}