Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package aggregator | |
import ( | |
"context" | |
"fmt" | |
"github.com/weaviate/weaviate/adapters/repos/db/helpers" | |
"github.com/weaviate/weaviate/adapters/repos/db/inverted" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/storobj" | |
) | |
func (a *Aggregator) vectorSearch(allow helpers.AllowList, vec []float32) ([]uint64, []float32, error) { | |
if a.params.ObjectLimit != nil { | |
return a.searchByVector(vec, a.params.ObjectLimit, allow) | |
} | |
return a.searchByVectorDistance(vec, allow) | |
} | |
func (a *Aggregator) searchByVector(searchVector []float32, limit *int, ids helpers.AllowList) ([]uint64, []float32, error) { | |
idsFound, dists, err := a.vectorIndex.SearchByVector(searchVector, *limit, ids) | |
if err != nil { | |
return idsFound, nil, err | |
} | |
if a.params.Certainty > 0 { | |
targetDist := float32(1-a.params.Certainty) * 2 | |
i := 0 | |
for _, dist := range dists { | |
if dist > targetDist { | |
break | |
} | |
i++ | |
} | |
return idsFound[:i], dists, nil | |
} | |
return idsFound, dists, nil | |
} | |
func (a *Aggregator) searchByVectorDistance(searchVector []float32, ids helpers.AllowList) ([]uint64, []float32, error) { | |
if a.params.Certainty <= 0 { | |
return nil, nil, fmt.Errorf("must provide certainty or objectLimit with vector search") | |
} | |
targetDist := float32(1-a.params.Certainty) * 2 | |
idsFound, dists, err := a.vectorIndex.SearchByVectorDistance(searchVector, targetDist, -1, ids) | |
if err != nil { | |
return nil, nil, fmt.Errorf("aggregate search by vector: %w", err) | |
} | |
return idsFound, dists, nil | |
} | |
func (a *Aggregator) objectVectorSearch(searchVector []float32, | |
allowList helpers.AllowList, | |
) ([]*storobj.Object, []float32, error) { | |
ids, dists, err := a.vectorSearch(allowList, searchVector) | |
if err != nil { | |
return nil, nil, err | |
} | |
bucket := a.store.Bucket(helpers.ObjectsBucketLSM) | |
objs, err := storobj.ObjectsByDocID(bucket, ids, additional.Properties{}) | |
if err != nil { | |
return nil, nil, fmt.Errorf("get objects by doc id: %w", err) | |
} | |
return objs, dists, nil | |
} | |
func (a *Aggregator) buildAllowList(ctx context.Context) (helpers.AllowList, error) { | |
var ( | |
allow helpers.AllowList | |
err error | |
) | |
if a.params.Filters != nil { | |
s := a.getSchema.GetSchemaSkipAuth() | |
allow, err = inverted.NewSearcher(a.logger, a.store, s, nil, | |
a.classSearcher, a.stopwords, a.shardVersion, a.isFallbackToSearchable, | |
a.tenant, a.nestedCrossRefLimit). | |
DocIDs(ctx, a.params.Filters, additional.Properties{}, | |
a.params.ClassName) | |
if err != nil { | |
return nil, fmt.Errorf("retrieve doc IDs from searcher: %w", err) | |
} | |
} | |
return allow, nil | |
} | |