KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.31 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package nearestneighbors
import (
"context"
"fmt"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/pkg/errors"
"github.com/tailor-inc/graphql/language/ast"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/search"
txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models"
)
// Defaults used when querying the contextionary for nearest neighbors.
const (
	// DefaultLimit is the neighbor count passed as n to
	// MultiNearestWordsByVector when the caller supplies no usable limit.
	DefaultLimit = 10

	// DefaultK is always passed as the k argument of
	// MultiNearestWordsByVector (exact semantics defined by the
	// contextionary — TODO confirm).
	DefaultK = 32
)
// Extender produces the "nearestNeighbors" additional property for search
// results by asking a contextionary for the words closest to each result's
// vector.
type Extender struct {
	// searcher serves the batched nearest-word lookups issued by Multi.
	searcher contextionary
}
// contextionary is the narrow dependency the Extender needs: a batched
// nearest-words lookup keyed by vector.
type contextionary interface {
	// MultiNearestWordsByVector looks up nearest words for every vector in
	// vectors. Multi relies on receiving exactly one entry per input vector,
	// in the same order. k and n are forwarded from DefaultK and the resolved
	// limit respectively (their precise meaning is defined by the
	// implementation — TODO confirm).
	MultiNearestWordsByVector(
		ctx context.Context,
		vectors [][]float32,
		k int,
		n int,
	) ([]*txt2vecmodels.NearestNeighbors, error)
}
// AdditionalPropertyDefaultValue reports the default value of this additional
// property, which is unconditionally true.
func (e *Extender) AdditionalPropertyDefaultValue() interface{} {
	const enabledByDefault = true
	return enabledByDefault
}
// AdditionalPropertyFn is the module hook that enriches the results in `in`
// with nearest neighbors. It simply delegates to Multi; params,
// argumentModuleParams and cfg are not consulted by this extender.
func (e *Extender) AdditionalPropertyFn(ctx context.Context,
	in []search.Result, params interface{}, limit *int,
	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	enriched, err := e.Multi(ctx, in, limit)
	return enriched, err
}
// ExtractAdditionalFn is the hook for turning GraphQL arguments into
// parameters for this additional property. The extender takes no arguments,
// so param is ignored and the constant true is returned.
func (e *Extender) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	_ = param // intentionally unused: no arguments are supported
	return true
}
// Single enriches one search result with nearest neighbors by running it
// through Multi as a batch of size one. A nil input yields (nil, nil); any
// error from Multi is returned unchanged.
func (e *Extender) Single(ctx context.Context, in *search.Result, limit *int) (*search.Result, error) {
	if in == nil {
		return nil, nil
	}

	// Dereferencing is safe here thanks to the nil guard above.
	batch := []search.Result{*in}
	enriched, err := e.Multi(ctx, batch, limit)
	if err != nil {
		return nil, err
	}

	first := &enriched[0]
	return first, nil
}
// Multi attaches a "nearestNeighbors" additional property to every result in
// `in`, computed from each result's vector via the contextionary searcher.
//
// Every result must carry a non-empty vector; otherwise an error naming the
// offending index is returned. The searcher is asked for DefaultK and the
// resolved limit (see limitOrDefault), and must return exactly one entry per
// input. Results are updated in place and the same slice is returned.
// A nil or empty input is returned as-is without contacting the searcher.
func (e *Extender) Multi(ctx context.Context, in []search.Result, limit *int) ([]search.Result, error) {
	if len(in) == 0 {
		// Nothing to enrich; skip the round trip to the searcher.
		return in, nil
	}

	vectors := make([][]float32, len(in))
	// Index instead of ranging by value: search.Result is a large struct.
	for i := range in {
		if len(in[i].Vector) == 0 {
			return nil, fmt.Errorf("item %d has no vector", i)
		}
		vectors[i] = in[i].Vector
	}

	neighbors, err := e.searcher.MultiNearestWordsByVector(ctx, vectors, DefaultK, limitOrDefault(limit))
	if err != nil {
		return nil, errors.Wrap(err, "get neighbors for search results")
	}
	// neighbors[i] is matched positionally to in[i] below, so the lengths
	// must agree.
	if len(neighbors) != len(in) {
		return nil, fmt.Errorf("inconsistent results: input=%d neighbors=%d", len(in), len(neighbors))
	}

	for i := range in {
		props := in[i].AdditionalProperties
		if props == nil {
			props = models.AdditionalProperties{}
		}
		props["nearestNeighbors"] = removeDollarElements(neighbors[i])
		in[i].AdditionalProperties = props
	}
	return in, nil
}
// NewExtender builds an Extender that issues its nearest-word lookups
// against searcher.
func NewExtender(searcher contextionary) *Extender {
	ext := &Extender{}
	ext.searcher = searcher
	return ext
}
// limitOrDefault resolves the neighbor count to request from the searcher.
// It returns *user when that is a positive number and DefaultLimit otherwise
// (nil, zero, or negative). Negative values are treated as "unset" so that
// sentinel limits such as -1 are never forwarded to the searcher.
func limitOrDefault(user *int) int {
	if user == nil || *user <= 0 {
		return DefaultLimit
	}
	return *user
}
// removeDollarElements returns a new NearestNeighbors containing every
// neighbor from in except those whose Concept begins with '$' (presumably
// internal/special contextionary entries — TODO confirm). Nil neighbors are
// dropped, and a nil input yields nil. The input is not modified.
func removeDollarElements(in *txt2vecmodels.NearestNeighbors) *txt2vecmodels.NearestNeighbors {
	if in == nil {
		return nil
	}

	kept := make([]*txt2vecmodels.NearestNeighbor, 0, len(in.Neighbors))
	for _, elem := range in.Neighbors {
		if elem == nil {
			continue
		}
		// Check the length first: indexing an empty Concept would panic.
		if len(elem.Concept) > 0 && elem.Concept[0] == '$' {
			continue
		}
		kept = append(kept, elem)
	}

	return &txt2vecmodels.NearestNeighbors{
		Neighbors: kept,
	}
}