KevinStephenson
Adding in weaviate code
b110593
raw
history blame
7.6 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package aggregator
import (
"context"
"fmt"
"github.com/pkg/errors"
"github.com/weaviate/weaviate/adapters/repos/db/docid"
"github.com/weaviate/weaviate/adapters/repos/db/helpers"
"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
"github.com/weaviate/weaviate/entities/aggregation"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/storobj"
"github.com/weaviate/weaviate/usecases/traverser/hybrid"
bolt "go.etcd.io/bbolt"
)
// grouper is the component which identifies the top-n groups for a specific
// group-by parameter. It is used as part of the grouped aggregator, which then
// additionally performs an aggregation for each group.
type grouper struct {
	*Aggregator
	// values maps each distinct group-by value to the set of doc IDs that hold
	// it. The inner map is used as a set so every docID is recorded at most
	// once per value (map[value][docID]struct, to keep docIds unique).
	values map[interface{}]map[uint64]struct{}
	// topGroups is filled by insertOrdered and kept sorted by descending
	// group Count.
	topGroups []group
	// limit is the number of groups insertOrdered keeps in topGroups.
	limit int
}
// newGrouper returns a grouper bound to the aggregator a. The values map is
// allocated up front so addItem can write into it; limit is the number of
// top groups that insertOrdered should retain.
func newGrouper(a *Aggregator, limit int) *grouper {
	g := &grouper{limit: limit, Aggregator: a}
	g.values = make(map[interface{}]map[uint64]struct{})
	return g
}
// Do collects the group-by values for the relevant objects and returns the
// top groups. A group-by path with more than one segment (a cross-reference
// path) is rejected. When the query has no filter, no search vector and no
// hybrid component, every object is scanned; otherwise only the candidate
// doc IDs produced by the filter/search are scanned.
func (g *grouper) Do(ctx context.Context) ([]group, error) {
	if len(g.params.GroupBy.Slice()) > 1 {
		return nil, fmt.Errorf("grouping by cross-refs not supported")
	}

	unrestricted := g.params.Filters == nil &&
		len(g.params.SearchVector) == 0 &&
		g.params.Hybrid == nil
	if unrestricted {
		return g.groupAll(ctx)
	}
	return g.groupFiltered(ctx)
}
// groupAll records the group-by value of every object in the objects bucket
// and then selects the top groups. It is used when the query carries no
// filter, search vector or hybrid component, so there is no candidate set to
// narrow the scan. The scan is aborted with ctx's error once ctx is done.
func (g *grouper) groupAll(ctx context.Context) ([]group, error) {
	err := ScanAllLSM(g.store, func(prop *models.PropertySchema, docID uint64) (bool, error) {
		// A full-bucket scan can be long; stop as soon as the caller gives up.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return false, ctxErr
		}
		return true, g.addElementById(prop, docID)
	})
	if err != nil {
		return nil, errors.Wrap(err, "group all (unfiltered)")
	}
	return g.aggregateAndSelect()
}
// groupFiltered resolves the candidate doc IDs (via filters, vector search or
// hybrid search), records the group-by value of each candidate object and
// then selects the top groups. Scan errors are wrapped for context, matching
// groupAll, and the scan is aborted with ctx's error once ctx is done.
func (g *grouper) groupFiltered(ctx context.Context) ([]group, error) {
	ids, err := g.fetchDocIDs(ctx)
	if err != nil {
		return nil, err
	}

	// Only the group-by property name is passed to the scan; presumably this
	// limits which properties are loaded per object — confirm in docid.
	props := []string{g.params.GroupBy.Property.String()}
	err = docid.ScanObjectsLSM(g.store, ids,
		func(prop *models.PropertySchema, docID uint64) (bool, error) {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return false, ctxErr
			}
			return true, g.addElementById(prop, docID)
		}, props)
	if err != nil {
		return nil, errors.Wrap(err, "group filtered")
	}
	return g.aggregateAndSelect()
}
// fetchDocIDs returns the doc IDs that take part in a restricted grouping.
// The allow list built from the query is passed to the vector or hybrid
// search when one is requested; with neither, the allow list's IDs are
// returned directly.
func (g *grouper) fetchDocIDs(ctx context.Context) ([]uint64, error) {
	allowList, err := g.buildAllowList(ctx)
	if err != nil {
		return nil, err
	}

	switch {
	case len(g.params.SearchVector) > 0:
		found, _, searchErr := g.vectorSearch(allowList, g.params.SearchVector)
		if searchErr != nil {
			return nil, fmt.Errorf("failed to perform vector search: %w", searchErr)
		}
		return found, nil
	case g.params.Hybrid != nil:
		found, searchErr := g.hybrid(ctx, allowList)
		if searchErr != nil {
			return nil, fmt.Errorf("hybrid search: %w", searchErr)
		}
		return found, nil
	default:
		return allowList.Slice(), nil
	}
}
// hybrid runs hybrid.Search with a keyword (BM25) search and a vector search
// as its two inputs and returns the DocID of every result it yields, in the
// order returned.
//
// Side effect: when g.params.ObjectLimit is nil, the keyword search sets it
// to hybrid.DefaultLimit before calling bm25Objects.
// NOTE(review): allowList is only handed to the vector search; confirm that
// bm25Objects applies the query filters itself.
func (g *grouper) hybrid(ctx context.Context, allowList helpers.AllowList) ([]uint64, error) {
	keywordSearch := func() ([]*storobj.Object, []float32, error) {
		ranking, err := g.buildHybridKeywordRanking()
		if err != nil {
			return nil, nil, fmt.Errorf("build hybrid keyword ranking: %w", err)
		}
		if g.params.ObjectLimit == nil {
			fallback := hybrid.DefaultLimit
			g.params.ObjectLimit = &fallback
		}
		objs, scores, err := g.bm25Objects(ctx, ranking)
		if err != nil {
			return nil, nil, fmt.Errorf("aggregate sparse search: %w", err)
		}
		return objs, scores, nil
	}

	vectorSearch := func(vec []float32) ([]*storobj.Object, []float32, error) {
		objs, dists, err := g.objectVectorSearch(vec, allowList)
		if err != nil {
			return nil, nil, fmt.Errorf("aggregate grouped dense search: %w", err)
		}
		return objs, dists, nil
	}

	searchParams := &hybrid.Params{
		HybridSearch: g.params.Hybrid,
		Keyword:      nil,
		Class:        g.params.ClassName.String(),
	}
	results, err := hybrid.Search(ctx, searchParams, g.logger, keywordSearch, vectorSearch, nil, nil)
	if err != nil {
		return nil, err
	}

	docIDs := make([]uint64, 0, len(results))
	for _, res := range results {
		docIDs = append(docIDs, res.DocID)
	}
	return docIDs, nil
}
// addElementById records docID under every value that the group-by property
// holds on the object whose properties are s. Array values ([]string,
// []float64, []bool, []interface{}) add one entry per element, reference
// values (models.MultipleRef) are grouped by beacon, and any other value is
// used as-is.
//
// Objects without properties, or without the group-by property, are skipped.
// A non-nil property schema that is not a map[string]interface{} is reported
// as an error instead of panicking.
func (g *grouper) addElementById(s *models.PropertySchema, docID uint64) error {
	if s == nil || *s == nil {
		return nil
	}
	// Comma-ok form: a single-value assertion here would panic on a schema of
	// any other dynamic type.
	props, ok := (*s).(map[string]interface{})
	if !ok {
		return fmt.Errorf("group by: unexpected property schema type %T", *s)
	}
	item, ok := props[g.params.GroupBy.Property.String()]
	if !ok {
		return nil
	}

	switch val := item.(type) {
	case []string:
		for _, v := range val {
			g.addItem(v, docID)
		}
	case []float64:
		for _, v := range val {
			g.addItem(v, docID)
		}
	case []bool:
		for _, v := range val {
			g.addItem(v, docID)
		}
	case []interface{}:
		for _, v := range val {
			g.addItem(v, docID)
		}
	case models.MultipleRef:
		for _, ref := range val {
			if ref == nil {
				// Skip nil entries rather than dereferencing them.
				continue
			}
			g.addItem(ref.Beacon, docID)
		}
	default:
		g.addItem(val, docID)
	}
	return nil
}
// addItem adds docID to the set of doc IDs stored for the group-by value
// item, creating that set on first sight of the value. Adding the same docID
// twice for one value is a no-op.
func (g *grouper) addItem(item interface{}, docID uint64) {
	if docs, seen := g.values[item]; seen {
		docs[docID] = struct{}{}
		return
	}
	g.values[item] = map[uint64]struct{}{docID: {}}
}
// aggregateAndSelect turns every collected group-by value into a group
// (value, number of distinct doc IDs, the doc IDs themselves) and feeds it to
// insertOrdered, returning the resulting topGroups. The error result is
// always nil.
//
// Values are visited in map iteration order, so the doc ID order within a
// group, and the relative order of groups with equal counts, is not fixed.
func (g *grouper) aggregateAndSelect() ([]group, error) {
	for value, docSet := range g.values {
		docIDs := make([]uint64, 0, len(docSet))
		for docID := range docSet {
			docIDs = append(docIDs, docID)
		}

		g.insertOrdered(group{
			docIDs: docIDs,
			res: aggregation.Group{
				Count: len(docIDs),
				GroupedBy: &aggregation.GroupedBy{
					Path:  g.params.GroupBy.Slice(),
					Value: value,
				},
			},
		})
	}
	return g.topGroups, nil
}
// insertOrdered inserts elem into topGroups, keeping the slice sorted by
// descending res.Count and at most g.limit entries long. A new element is
// placed ahead of existing elements with an equal Count. An element whose
// position would be at or beyond g.limit is discarded, so a limit <= 0 keeps
// no groups at all.
func (g *grouper) insertOrdered(elem group) {
	// pos is the first index whose Count does not exceed elem's, i.e. where
	// elem belongs; it stays len(topGroups) when elem is the smallest.
	pos := len(g.topGroups)
	for i := range g.topGroups {
		if g.topGroups[i].res.Count <= elem.res.Count {
			pos = i
			break
		}
	}
	if pos >= g.limit {
		return
	}

	// Grow by one, shift the tail right and drop elem into the gap.
	g.topGroups = append(g.topGroups, group{})
	copy(g.topGroups[pos+1:], g.topGroups[pos:])
	g.topGroups[pos] = elem

	if last := len(g.topGroups) - 1; last >= g.limit {
		// Clear the evicted slot so its doc IDs are not kept alive by the
		// backing array, then shrink back to the limit.
		g.topGroups[last] = group{}
		g.topGroups = g.topGroups[:last]
	}
}
// ScanAll iterates over every row in the bolt objects bucket, decodes each
// value into a storage object and passes its properties and doc ID to scan.
// It stops at, and returns, the first decode error or error returned by scan.
// The continue flag returned by scan is not consulted.
// TODO: where should this live?
func ScanAll(tx *bolt.Tx, scan docid.ObjectScanFn) error {
	b := tx.Bucket(helpers.ObjectsBucket)
	if b == nil {
		return fmt.Errorf("objects bucket not found")
	}

	// ForEach stops at the first error returned by the callback and returns
	// it; propagate it so decode and scan failures reach the caller.
	return b.ForEach(func(_, v []byte) error {
		elem, err := storobj.FromBinary(v)
		if err != nil {
			return errors.Wrapf(err, "unmarshal data object")
		}
		// scanAll has no abort, so we can ignore the first arg
		properties := elem.Properties()
		_, err = scan(&properties, elem.DocID())
		return err
	})
}
// ScanAllLSM walks every entry of the LSM objects bucket in cursor order,
// decodes each value into a storage object and hands its properties and doc
// ID to scan. The first decode error, or error returned by scan, ends the
// walk and is returned. The continue flag returned by scan is not consulted.
// The cursor is closed on every return path.
func ScanAllLSM(store *lsmkv.Store, scan docid.ObjectScanFn) error {
	bucket := store.Bucket(helpers.ObjectsBucketLSM)
	if bucket == nil {
		return fmt.Errorf("objects bucket not found")
	}

	cursor := bucket.Cursor()
	defer cursor.Close()

	key, raw := cursor.First()
	for key != nil {
		obj, err := storobj.FromBinary(raw)
		if err != nil {
			return errors.Wrapf(err, "unmarshal data object")
		}
		props := obj.Properties()
		if _, scanErr := scan(&props, obj.DocID()); scanErr != nil {
			return scanErr
		}
		key, raw = cursor.Next()
	}
	return nil
}