//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package inverted

import (
	"context"
	"encoding/binary"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/sroar"
	"golang.org/x/sync/errgroup"

	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
	"github.com/weaviate/weaviate/adapters/repos/db/inverted/stopwords"
	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
	"github.com/weaviate/weaviate/adapters/repos/db/priorityqueue"
	"github.com/weaviate/weaviate/adapters/repos/db/propertyspecific"
	"github.com/weaviate/weaviate/entities/inverted"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/searchparams"
	"github.com/weaviate/weaviate/entities/storobj"
)
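
// BM25Searcher performs BM25F keyword searches against the searchable
// inverted-index buckets of a shard's LSM store.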
type BM25Searcher struct {
config schema.BM25Config
store *lsmkv.Store
schema schema.Schema
classSearcher ClassSearcher // to allow recursive searches on ref-props
propIndices propertyspecific.Indices
propLenTracker propLengthRetriever
logger logrus.FieldLogger
shardVersion uint16
}
type propLengthRetriever interface {
PropertyMean(prop string) (float32, error)
}
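
// NewBM25Searcher constructs a BM25Searcher for a single shard. The
// propLenTracker supplies mean property lengths, which BM25 needs for its
// length normalization.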
func NewBM25Searcher(config schema.BM25Config, store *lsmkv.Store,
schema schema.Schema, propIndices propertyspecific.Indices,
classSearcher ClassSearcher, propLenTracker propLengthRetriever,
logger logrus.FieldLogger, shardVersion uint16,
) *BM25Searcher {
return &BM25Searcher{
config: config,
store: store,
schema: schema,
propIndices: propIndices,
classSearcher: classSearcher,
propLenTracker: propLenTracker,
logger: logger.WithField("action", "bm25_search"),
shardVersion: shardVersion,
}
}
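
// BM25F runs a BM25F keyword search over the given properties and returns the
// matching objects with their scores, best match first. Property names may
// carry a boost suffix such as "title^2". A minimal usage sketch (the
// searcher, context, and class setup are assumed, not shown here):
//
//	objs, scores, err := searcher.BM25F(ctx, nil, "Article", 10,
//		searchparams.KeywordRanking{
//			Query:      "vector database",
//			Properties: []string{"title^2", "description"},
//		})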
func (b *BM25Searcher) BM25F(ctx context.Context, filterDocIds helpers.AllowList,
	className schema.ClassName, limit int, keywordRanking searchparams.KeywordRanking,
) ([]*storobj.Object, []float32, error) {
// WEAVIATE-471 - If a property is not searchable, return an error
for _, property := range keywordRanking.Properties {
if !PropertyHasSearchableIndex(b.schema.Objects, string(className), property) {
return nil, nil, inverted.NewMissingSearchableIndexError(property)
}
}
class, err := schema.GetClassByName(b.schema.Objects, string(className))
if err != nil {
return nil, nil, err
}
objs, scores, err := b.wand(ctx, filterDocIds, class, keywordRanking, limit)
if err != nil {
return nil, nil, errors.Wrap(err, "wand")
}
return objs, scores, nil
}
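
// GetPropertyLengthTracker returns the shard's property length tracker. The
// tracker is assumed to be a *JsonShardMetaData; the type assertion panics
// otherwise.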
func (b *BM25Searcher) GetPropertyLengthTracker() *JsonShardMetaData {
return b.propLenTracker.(*JsonShardMetaData)
}
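
// wand tokenizes the query once per tokenization scheme, creates one term per
// (scheme, query-term) pair in parallel, and then merges the per-term posting
// lists with the WAND top-k algorithm (see getTopKHeap).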
func (b *BM25Searcher) wand(
ctx context.Context, filterDocIds helpers.AllowList, class *models.Class, params searchparams.KeywordRanking, limit int,
) ([]*storobj.Object, []float32, error) {
N := float64(b.store.Bucket(helpers.ObjectsBucketLSM).Count())
var stopWordDetector *stopwords.Detector
if class.InvertedIndexConfig != nil && class.InvertedIndexConfig.Stopwords != nil {
var err error
stopWordDetector, err = stopwords.NewDetectorFromConfig(*(class.InvertedIndexConfig.Stopwords))
if err != nil {
return nil, nil, err
}
}
	// There are currently four tokenization schemes: word, lowercase,
	// whitespace and field. The query is tokenized once per scheme, the
	// properties using that scheme are searched for the resulting terms, and
	// the per-term results are combined at the end using WAND.
tokenizationsOrdered := []string{
models.PropertyTokenizationWord,
models.PropertyTokenizationLowercase,
models.PropertyTokenizationWhitespace,
models.PropertyTokenizationField,
}
queryTermsByTokenization := map[string][]string{}
duplicateBoostsByTokenization := map[string][]int{}
propNamesByTokenization := map[string][]string{}
propertyBoosts := make(map[string]float32, len(params.Properties))
for _, tokenization := range tokenizationsOrdered {
queryTermsByTokenization[tokenization], duplicateBoostsByTokenization[tokenization] = helpers.TokenizeAndCountDuplicates(tokenization, params.Query)
// stopword filtering for word tokenization
if tokenization == models.PropertyTokenizationWord {
queryTermsByTokenization[tokenization], duplicateBoostsByTokenization[tokenization] = b.removeStopwordsFromQueryTerms(queryTermsByTokenization[tokenization], duplicateBoostsByTokenization[tokenization], stopWordDetector)
}
propNamesByTokenization[tokenization] = make([]string, 0)
}
averagePropLength := 0.
for _, propertyWithBoost := range params.Properties {
property := propertyWithBoost
		propBoost := 1
		if strings.Contains(propertyWithBoost, "^") {
			property = strings.Split(propertyWithBoost, "^")[0]
			boostStr := strings.Split(propertyWithBoost, "^")[1]
			var err error
			propBoost, err = strconv.Atoi(boostStr)
			if err != nil {
				// don't ignore a malformed boost such as "title^x": Atoi would
				// return 0 and silently zero out the property's contribution
				return nil, nil, errors.Wrapf(err, "parse boost %q for property %q", boostStr, property)
			}
		}
propertyBoosts[property] = float32(propBoost)
propMean, err := b.GetPropertyLengthTracker().PropertyMean(property)
if err != nil {
return nil, nil, err
}
averagePropLength += float64(propMean)
prop, err := schema.GetPropertyByName(class, property)
if err != nil {
return nil, nil, err
}
switch dt, _ := schema.AsPrimitive(prop.DataType); dt {
case schema.DataTypeText, schema.DataTypeTextArray:
if _, exists := propNamesByTokenization[prop.Tokenization]; !exists {
return nil, nil, fmt.Errorf("cannot handle tokenization '%v' of property '%s'",
prop.Tokenization, prop.Name)
}
propNamesByTokenization[prop.Tokenization] = append(propNamesByTokenization[prop.Tokenization], property)
default:
return nil, nil, fmt.Errorf("cannot handle datatype '%v' of property '%s'", dt, prop.Name)
}
}
averagePropLength = averagePropLength / float64(len(params.Properties))
// preallocate the results
lengthAllResults := 0
for tokenization, propNames := range propNamesByTokenization {
if len(propNames) > 0 {
lengthAllResults += len(queryTermsByTokenization[tokenization])
}
}
results := make(terms, lengthAllResults)
indices := make([]map[uint64]int, lengthAllResults)
var eg errgroup.Group
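	// _NUMCPU is defined elsewhere in this package; we assume it reflects the
	// number of usable cores (e.g. runtime.GOMAXPROCS(0)) to cap concurrency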
eg.SetLimit(_NUMCPU)
offset := 0
for _, tokenization := range tokenizationsOrdered {
propNames := propNamesByTokenization[tokenization]
if len(propNames) > 0 {
queryTerms := queryTermsByTokenization[tokenization]
duplicateBoosts := duplicateBoostsByTokenization[tokenization]
for i := range queryTerms {
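				// capture stable copies of the loop variables for the closure
				// below (Go versions before 1.22 reuse the loop variable)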
j := i
k := i + offset
eg.Go(func() error {
termResult, docIndices, err := b.createTerm(N, filterDocIds, queryTerms[j], propNames,
propertyBoosts, duplicateBoosts[j], params.AdditionalExplanations)
if err != nil {
return err
}
results[k] = termResult
indices[k] = docIndices
return nil
})
}
offset += len(queryTerms)
}
}
if err := eg.Wait(); err != nil {
return nil, nil, err
}
	// a limit of 0 means no limit => return all results. Sum up the lengths of
	// the per-term results to get an upper bound on how many results there are
if limit == 0 {
for _, ind := range indices {
limit += len(ind)
}
}
// the results are needed in the original order to be able to locate frequency/property length for the top-results
resultsOriginalOrder := make(terms, len(results))
copy(resultsOriginalOrder, results)
topKHeap := b.getTopKHeap(limit, results, averagePropLength)
return b.getTopKObjects(topKHeap, resultsOriginalOrder, indices, params.AdditionalExplanations)
}
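
// removeStopwordsFromQueryTerms filters stopwords out of queryTerms in place
// using swap-with-last removal, keeping duplicateBoost aligned with the terms.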
func (b *BM25Searcher) removeStopwordsFromQueryTerms(queryTerms []string,
	duplicateBoost []int, detector *stopwords.Detector,
) ([]string, []int) {
if detector == nil || len(queryTerms) == 0 {
return queryTerms, duplicateBoost
}
i := 0
WordLoop:
for {
if i == len(queryTerms) {
return queryTerms, duplicateBoost
}
queryTerm := queryTerms[i]
if detector.IsStopword(queryTerm) {
queryTerms[i] = queryTerms[len(queryTerms)-1]
queryTerms = queryTerms[:len(queryTerms)-1]
duplicateBoost[i] = duplicateBoost[len(duplicateBoost)-1]
duplicateBoost = duplicateBoost[:len(duplicateBoost)-1]
continue WordLoop
}
i++
}
}
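
// getTopKObjects resolves the doc IDs held by the top-k heap into stored
// objects and, when additionalExplanations is set, attaches the per-term
// frequency and property length to each object.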
func (b *BM25Searcher) getTopKObjects(topKHeap *priorityqueue.Queue[any],
results terms, indices []map[uint64]int, additionalExplanations bool,
) ([]*storobj.Object, []float32, error) {
objectsBucket := b.store.Bucket(helpers.ObjectsBucketLSM)
if objectsBucket == nil {
return nil, nil, errors.Errorf("objects bucket not found")
}
objects := make([]*storobj.Object, 0, topKHeap.Len())
scores := make([]float32, 0, topKHeap.Len())
buf := make([]byte, 8)
for topKHeap.Len() > 0 {
res := topKHeap.Pop()
binary.LittleEndian.PutUint64(buf, res.ID)
objectByte, err := objectsBucket.GetBySecondary(0, buf)
if err != nil {
return nil, nil, err
}
if len(objectByte) == 0 {
b.logger.Warnf("Skipping object in BM25: object with id %v has a length of 0 bytes.", res.ID)
continue
}
obj, err := storobj.FromBinary(objectByte)
if err != nil {
return nil, nil, err
}
if additionalExplanations {
// add score explanation
if obj.AdditionalProperties() == nil {
obj.Object.Additional = make(map[string]interface{})
}
for j, result := range results {
if termIndice, ok := indices[j][res.ID]; ok {
queryTerm := result.queryTerm
obj.Object.Additional["BM25F_"+queryTerm+"_frequency"] = result.data[termIndice].frequency
obj.Object.Additional["BM25F_"+queryTerm+"_propLength"] = result.data[termIndice].propLength
}
}
}
objects = append(objects, obj)
scores = append(scores, res.Dist)
}
return objects, scores, nil
}
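
// getTopKHeap runs the WAND loop: terms are kept sorted by their current doc
// ID, and candidates are skipped wholesale whenever the summed max impacts
// (idf) of the terms cannot beat the worst score on the min-heap.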
func (b *BM25Searcher) getTopKHeap(limit int, results terms, averagePropLength float64,
) *priorityqueue.Queue[any] {
topKHeap := priorityqueue.NewMin[any](limit)
worstDist := float64(-10000) // tf score can be negative
sort.Sort(results)
for {
if results.completelyExhausted() || results.pivot(worstDist) {
return topKHeap
}
id, score := results.scoreNext(averagePropLength, b.config)
if topKHeap.Len() < limit || topKHeap.Top().Dist < float32(score) {
topKHeap.Insert(id, float32(score))
for topKHeap.Len() > limit {
topKHeap.Pop()
}
// only update the worst distance when the queue is full, otherwise results can be missing if the first
// entry that is checked already has a very high score
if topKHeap.Len() >= limit {
worstDist = float64(topKHeap.Top().Dist)
}
}
}
}
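
// createTerm reads the posting list for one query term from every requested
// property, merges the entries per doc ID (summing frequencies and property
// lengths), and computes the term's idf as
//
//	idf = ln(1 + (N-n+0.5)/(n+0.5)) * duplicateTextBoost
//
// where N is the shard's object count and n the number of documents that
// contain the term, including filtered-out ones so that a filter does not
// skew the idf.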
func (b *BM25Searcher) createTerm(N float64, filterDocIds helpers.AllowList, query string,
	propertyNames []string, propertyBoosts map[string]float32, duplicateTextBoost int,
	additionalExplanations bool,
) (term, map[uint64]int, error) {
termResult := term{queryTerm: query}
filteredDocIDs := sroar.NewBitmap() // to build the global n if there is a filter
allMsAndProps := make(AllMapPairsAndPropName, 0, len(propertyNames))
for _, propName := range propertyNames {
bucket := b.store.Bucket(helpers.BucketSearchableFromPropNameLSM(propName))
if bucket == nil {
return termResult, nil, fmt.Errorf("could not find bucket for property %v", propName)
}
preM, err := bucket.MapList([]byte(query))
if err != nil {
return termResult, nil, err
}
var m []lsmkv.MapPair
if filterDocIds != nil {
m = make([]lsmkv.MapPair, 0, len(preM))
for _, val := range preM {
docID := binary.BigEndian.Uint64(val.Key)
if filterDocIds.Contains(docID) {
m = append(m, val)
} else {
filteredDocIDs.Set(docID)
}
}
} else {
m = preM
}
if len(m) == 0 {
continue
}
allMsAndProps = append(allMsAndProps, MapPairsAndPropName{MapPairs: m, propname: propName})
}
	// Sort ascending. This has two effects:
	// 1) We can skip writing the indices from the last property to the map (see the next comment). Having the
	//    biggest property at the end therefore saves the most writes on average.
	// 2) For the first property all entries are new, so we can create the map with the correct size up front.
	//    Moving the second-biggest entry to the front saves additional allocations later.
sort.Sort(allMsAndProps)
if len(allMsAndProps) > 2 {
allMsAndProps[len(allMsAndProps)-2], allMsAndProps[0] = allMsAndProps[0], allMsAndProps[len(allMsAndProps)-2]
}
	var docMapPairs []docPointerWithScore
	var docMapPairsIndices map[uint64]int
for i, mAndProps := range allMsAndProps {
m := mAndProps.MapPairs
propName := mAndProps.propname
// The indices are needed for two things:
// a) combining the results of different properties
// b) Retrieve additional information that helps to understand the results when debugging. The retrieval is done
// in a later step, after it is clear which objects are the most relevant
//
// When b) is not needed the results from the last property do not need to be added to the index-map as there
// won't be any follow-up combinations.
includeIndicesForLastElement := false
if additionalExplanations || i < len(allMsAndProps)-1 {
includeIndicesForLastElement = true
}
// only create maps/slices if we know how many entries there are
if docMapPairs == nil {
docMapPairs = make([]docPointerWithScore, 0, len(m))
docMapPairsIndices = make(map[uint64]int, len(m))
for k, val := range m {
if len(val.Value) < 8 {
b.logger.Warnf("Skipping pair in BM25: MapPair.Value should be 8 bytes long, but is %d.", len(val.Value))
continue
}
freqBits := binary.LittleEndian.Uint32(val.Value[0:4])
propLenBits := binary.LittleEndian.Uint32(val.Value[4:8])
docMapPairs = append(docMapPairs,
docPointerWithScore{
id: binary.BigEndian.Uint64(val.Key),
frequency: math.Float32frombits(freqBits) * propertyBoosts[propName],
propLength: math.Float32frombits(propLenBits),
})
if includeIndicesForLastElement {
docMapPairsIndices[binary.BigEndian.Uint64(val.Key)] = k
}
}
} else {
for _, val := range m {
if len(val.Value) < 8 {
b.logger.Warnf("Skipping pair in BM25: MapPair.Value should be 8 bytes long, but is %d.", len(val.Value))
continue
}
key := binary.BigEndian.Uint64(val.Key)
ind, ok := docMapPairsIndices[key]
freqBits := binary.LittleEndian.Uint32(val.Value[0:4])
propLenBits := binary.LittleEndian.Uint32(val.Value[4:8])
if ok {
docMapPairs[ind].propLength += math.Float32frombits(propLenBits)
docMapPairs[ind].frequency += math.Float32frombits(freqBits) * propertyBoosts[propName]
} else {
docMapPairs = append(docMapPairs,
docPointerWithScore{
id: binary.BigEndian.Uint64(val.Key),
frequency: math.Float32frombits(freqBits) * propertyBoosts[propName],
propLength: math.Float32frombits(propLenBits),
})
if includeIndicesForLastElement {
docMapPairsIndices[binary.BigEndian.Uint64(val.Key)] = len(docMapPairs) - 1 // current last entry
}
}
}
}
}
if docMapPairs == nil {
termResult.exhausted = true
return termResult, docMapPairsIndices, nil
}
termResult.data = docMapPairs
n := float64(len(docMapPairs))
if filterDocIds != nil {
n += float64(filteredDocIDs.GetCardinality())
}
termResult.idf = math.Log(float64(1)+(N-n+0.5)/(n+0.5)) * float64(duplicateTextBoost)
termResult.posPointer = 0
termResult.idPointer = termResult.data[0].id
return termResult, docMapPairsIndices, nil
}
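
// term iterates over one query term's merged posting list; idPointer and
// posPointer track the current document.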
type term struct {
// doubles as max impact (with tf=1, the max impact would be 1*idf), if there
// is a boost for a queryTerm, simply apply it here once
idf float64
idPointer uint64
posPointer uint64
data []docPointerWithScore
exhausted bool
queryTerm string
}
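
// scoreAndAdvance returns the current document's ID and its BM25 score
// contribution for this term, then advances to the next posting. The term
// frequency is normalized as
//
//	tf = freq / (freq + k1*(1 - b + b*propLength/averagePropLength))
//
// and the returned score is tf * idf.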
func (t *term) scoreAndAdvance(averagePropLength float64, config schema.BM25Config) (uint64, float64) {
id := t.idPointer
pair := t.data[t.posPointer]
freq := float64(pair.frequency)
tf := freq / (freq + config.K1*(1-config.B+config.B*float64(pair.propLength)/averagePropLength))
// advance
t.posPointer++
if t.posPointer >= uint64(len(t.data)) {
t.exhausted = true
} else {
t.idPointer = t.data[t.posPointer].id
}
return id, tf * t.idf
}
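
// advanceAtLeast skips postings until the term's current doc ID is at least
// minID, marking the term exhausted if its posting list runs out first.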
func (t *term) advanceAtLeast(minID uint64) {
for t.idPointer < minID {
t.posPointer++
if t.posPointer >= uint64(len(t.data)) {
t.exhausted = true
return
}
t.idPointer = t.data[t.posPointer].id
}
}
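
// terms is kept sorted by each term's current doc ID (see Less below); WAND
// relies on this ordering when selecting a pivot.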
type terms []term
func (t terms) completelyExhausted() bool {
for i := range t {
if !t[i].exhausted {
return false
}
}
return true
}
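
// pivot performs the WAND pivoting step. It returns true (abort) when even
// the summed max impacts of all terms cannot reach minScore. Otherwise every
// term below the pivot's doc ID is advanced to it, since no smaller doc ID
// can still make the top-k.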
func (t terms) pivot(minScore float64) bool {
minID, pivotPoint, abort := t.findMinID(minScore)
if abort {
return true
}
if pivotPoint == 0 {
return false
}
t.advanceAllAtLeast(minID)
sort.Sort(t)
return false
}
func (t terms) advanceAllAtLeast(minID uint64) {
for i := range t {
t[i].advanceAtLeast(minID)
}
}
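
// findMinID walks the terms in doc-ID order, accumulating their max impacts
// (idf), and returns the doc ID at which the running sum first reaches
// minScore. The returned bool signals abort: even the full sum falls short.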
func (t terms) findMinID(minScore float64) (uint64, int, bool) {
cumScore := float64(0)
for i, term := range t {
if term.exhausted {
continue
}
cumScore += term.idf
if cumScore >= minScore {
return term.idPointer, i, false
}
}
return 0, 0, true
}
func (t terms) findFirstNonExhausted() (int, bool) {
for i := range t {
if !t[i].exhausted {
return i, true
}
}
return -1, false
}
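
// scoreNext scores the smallest unscored doc ID: every term currently
// pointing at that ID contributes its partial score and is advanced.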
func (t terms) scoreNext(averagePropLength float64, config schema.BM25Config) (uint64, float64) {
pos, ok := t.findFirstNonExhausted()
if !ok {
// done, nothing left to score
return 0, 0
}
id := t[pos].idPointer
var cumScore float64
for i := pos; i < len(t); i++ {
if t[i].idPointer != id || t[i].exhausted {
continue
}
_, score := t[i].scoreAndAdvance(averagePropLength, config)
cumScore += score
}
sort.Sort(t) // pointer was advanced in scoreAndAdvance
return id, cumScore
}
// provide sort interface
func (t terms) Len() int {
return len(t)
}
func (t terms) Less(i, j int) bool {
return t[i].idPointer < t[j].idPointer
}
func (t terms) Swap(i, j int) {
t[i], t[j] = t[j], t[i]
}
type MapPairsAndPropName struct {
propname string
MapPairs []lsmkv.MapPair
}
type AllMapPairsAndPropName []MapPairsAndPropName
// provide sort interface
func (m AllMapPairsAndPropName) Len() int {
return len(m)
}
func (m AllMapPairsAndPropName) Less(i, j int) bool {
return len(m[i].MapPairs) < len(m[j].MapPairs)
}
func (m AllMapPairsAndPropName) Swap(i, j int) {
m[i], m[j] = m[j], m[i]
}
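
// PropertyHasSearchableIndex reports whether the given (possibly boosted,
// e.g. "title^2") property of the class has a searchable inverted index.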
func PropertyHasSearchableIndex(schemaDefinition *models.Schema, className, tentativePropertyName string) bool {
propertyName := strings.Split(tentativePropertyName, "^")[0]
	c, err := schema.GetClassByName(schemaDefinition, className)
if err != nil {
return false
}
p, err := schema.GetPropertyByName(c, propertyName)
if err != nil {
return false
}
return HasSearchableIndex(p)
}