KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.58 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package grouper
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/usecases/vectorizer"
)
// Grouper groups or merges search results by how related they are, where
// relatedness is the normalized distance between the results' vectors.
type Grouper struct {
	logger logrus.FieldLogger
}

// New creates a Grouper that uses logger for diagnostic (debug) output while
// grouping. logger must be non-nil: flattening the groups logs through it
// unconditionally.
func New(logger logrus.FieldLogger) *Grouper {
	return &Grouper{logger: logger}
}
// Group clusters the results in `in` and reduces each cluster to a single
// result.
//
// Results are visited in order. Each result joins the first existing group
// that has an element whose normalized vector distance to it is below force;
// if no group qualifies, the result opens a new group. The groups are then
// flattened according to strategy ("closest" or "merge"). Any other strategy
// yields an error.
func (g *Grouper) Group(in []search.Result, strategy string,
	force float32,
) ([]search.Result, error) {
	collected := groups{logger: g.logger}
	for _, res := range in {
		idx, found := collected.hasMatch(res.Vector, force)
		if found {
			collected.Elements[idx].add(res)
			continue
		}
		collected.new(res)
	}
	return collected.flatten(strategy)
}
// group is a cluster of related search results. Its first element is the
// result that opened the group; later members are appended after it.
type group struct {
	Elements []search.Result `json:"elements"`
}

// add appends r as the newest member of the group.
func (g *group) add(r search.Result) {
	g.Elements = append(g.Elements, r)
}
// matches reports whether vector's normalized distance to at least one
// element of the group is strictly below force. Elements whose distance
// cannot be computed are skipped silently; use matchesLogged to have such
// errors reported.
func (g group) matches(vector []float32, force float32) bool {
	return g.matchesLogged(vector, force, nil)
}

// matchesLogged behaves like matches. In addition, when logger is non-nil,
// it logs (at warning level) every error returned by
// vectorizer.NormalizedDistance before skipping that element.
func (g group) matchesLogged(vector []float32, force float32,
	logger logrus.FieldLogger,
) bool {
	// Iterate over all group Elements and consider it a match if any matches.
	// Index into the slice rather than ranging by value so that each
	// search.Result is not copied.
	for i := range g.Elements {
		dist, err := vectorizer.NormalizedDistance(vector, g.Elements[i].Vector)
		if err != nil {
			// We don't expect to ever see this error, so grouping is not
			// aborted over it. We still surface it in case that assumption
			// turns out to be wrong.
			if logger != nil {
				logger.WithError(err).
					WithField("action", "grouping_distance").
					Warn("could not compute distance to group element, skipping it")
			}
			continue
		}
		if dist < force {
			return true
		}
	}
	return false
}

// groups is the ordered list of groups built up by Grouper.Group, together
// with the logger used for diagnostics while grouping and flattening.
type groups struct {
	Elements []group `json:"elements"`
	logger   logrus.FieldLogger
}

// hasMatch returns the index of the first group that vector matches (see
// group.matches) and true. It returns -1 and false if no group matches.
// Distance errors encountered along the way are logged via gs.logger.
func (gs groups) hasMatch(vector []float32, force float32) (int, bool) {
	for pos, candidate := range gs.Elements {
		if candidate.matchesLogged(vector, force, gs.logger) {
			return pos, true
		}
	}
	return -1, false
}
// new opens a fresh group whose only member is item and places it after all
// existing groups.
func (gs *groups) new(item search.Result) {
	fresh := group{Elements: []search.Result{item}}
	gs.Elements = append(gs.Elements, fresh)
}
// flatten reduces every group to a single search.Result according to
// strategy, preserving group order:
//   - "closest" keeps each group's first element (see flattenClosest)
//   - "merge" delegates to group.flattenMerge for each group (see flattenMerge)
//
// Any other strategy is rejected with an error. The groups are logged at
// debug level before flattening, and the flattened results afterwards.
func (gs groups) flatten(strategy string) ([]search.Result, error) {
	gs.logger.WithField("object", "grouping_before_flatten").
		WithField("strategy", strategy).
		WithField("groups", gs.Elements).
		Debug("group before flattening")

	var (
		out []search.Result
		err error
	)
	switch strategy {
	case "closest":
		out, err = gs.flattenClosest()
	case "merge":
		out, err = gs.flattenMerge()
	default:
		return nil, fmt.Errorf("unrecognized grouping strategy '%s'", strategy)
	}
	if err != nil {
		return nil, err
	}

	// Log the flattened output, not gs.Elements again. Flattening does not
	// modify the groups, so re-logging them would only repeat the
	// "before" entry.
	gs.logger.WithField("object", "grouping_after_flatten").
		WithField("strategy", strategy).
		WithField("results", out).
		Debug("group after flattening")
	return out, nil
}
// flattenClosest returns the first element of every group, in group order.
// That element is the result that opened the group; it is presumably the
// closest one, assuming the input results arrive ordered by relevance
// (TODO confirm). The returned error is always nil.
func (gs groups) flattenClosest() ([]search.Result, error) {
	picked := make([]search.Result, 0, len(gs.Elements))
	for i := range gs.Elements {
		// hard-code "closest" strategy for now
		picked = append(picked, gs.Elements[i].Elements[0])
	}
	return picked, nil
}
// flattenMerge merges every group into a single result via
// group.flattenMerge, preserving group order.
//
// The first group that fails to merge aborts the operation. Its error is
// returned prefixed with the group's index and wrapped with %w, so callers
// can still inspect the underlying cause with errors.Is / errors.As.
func (gs groups) flattenMerge() ([]search.Result, error) {
	out := make([]search.Result, len(gs.Elements))
	for i, grp := range gs.Elements {
		merged, err := grp.flattenMerge()
		if err != nil {
			return nil, fmt.Errorf("group %d: %w", i, err)
		}
		out[i] = merged
	}
	return out, nil
}