KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.15 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package hnsw
import (
"context"
"github.com/pkg/errors"
"github.com/weaviate/weaviate/adapters/repos/db/helpers"
"github.com/weaviate/weaviate/adapters/repos/db/priorityqueue"
"github.com/weaviate/weaviate/entities/storobj"
)
func (h *hnsw) selectNeighborsHeuristic(input *priorityqueue.Queue[any],
max int, denyList helpers.AllowList,
) error {
if input.Len() < max {
return nil
}
// TODO, if this solution stays we might need something with fewer allocs
ids := make([]uint64, input.Len())
closestFirst := h.pools.pqHeuristic.GetMin(input.Len())
i := uint64(0)
for input.Len() > 0 {
elem := input.Pop()
closestFirst.InsertWithValue(elem.ID, elem.Dist, i)
ids[i] = elem.ID
i++
}
var returnList []priorityqueue.Item[uint64]
if h.compressed.Load() {
bag := h.compressor.NewBag()
for _, id := range ids {
err := bag.Load(context.Background(), id)
if err != nil {
return err
}
}
returnList = h.pools.pqItemSlice.Get().([]priorityqueue.Item[uint64])
for closestFirst.Len() > 0 && len(returnList) < max {
curr := closestFirst.Pop()
if denyList != nil && denyList.Contains(curr.ID) {
continue
}
distToQuery := curr.Dist
good := true
for _, item := range returnList {
peerDist, err := bag.Distance(curr.ID, item.ID)
if err != nil {
return err
}
if peerDist < distToQuery {
good = false
break
}
}
if good {
returnList = append(returnList, curr)
}
}
} else {
vecs, errs := h.multiVectorForID(context.TODO(), ids)
returnList = h.pools.pqItemSlice.Get().([]priorityqueue.Item[uint64])
for closestFirst.Len() > 0 && len(returnList) < max {
curr := closestFirst.Pop()
if denyList != nil && denyList.Contains(curr.ID) {
continue
}
distToQuery := curr.Dist
currVec := vecs[curr.Value]
if err := errs[curr.Value]; err != nil {
var e storobj.ErrNotFound
if errors.As(err, &e) {
h.handleDeletedNode(e.DocID)
continue
} else {
// not a typed error, we can recover from, return with err
return errors.Wrapf(err,
"unrecoverable error for docID %d", curr.ID)
}
}
good := true
for _, item := range returnList {
peerDist, _, _ := h.distancerProvider.SingleDist(currVec,
vecs[item.Value])
if peerDist < distToQuery {
good = false
break
}
}
if good {
returnList = append(returnList, curr)
}
}
}
h.pools.pqHeuristic.Put(closestFirst)
for _, retElem := range returnList {
input.Insert(retElem.ID, retElem.Dist)
}
// rewind and return to pool
returnList = returnList[:0]
//nolint:staticcheck
h.pools.pqItemSlice.Put(returnList)
return nil
}