SemanticSearchPOC / usecases /traverser /near_params_vector.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
8.41 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package traverser
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/go-openapi/strfmt"
	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/schema/crossref"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/entities/searchparams"
	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
)
// nearParamsVector resolves the query vector for a search from its near*
// parameters: an explicit nearVector, a nearObject reference (ID or beacon)
// whose stored vector is reused, or a single module-provided search param.
type nearParamsVector struct {
// modulesProvider validates and vectorizes module search params. It may be
// nil; module params are then not validated, and vectorFromModules returns
// a "no modules defined" error.
modulesProvider ModulesProvider
// search looks up stored objects so their vectors can be reused.
search nearParamsSearcher
}
// nearParamsSearcher is the object-lookup dependency of nearParamsVector,
// used to fetch stored objects whose vectors back nearObject searches.
type nearParamsSearcher interface {
// Object fetches a single object of className by id (within tenant).
// nearParamsVector treats a nil result with a nil error as "vector not found".
Object(ctx context.Context, className string, id strfmt.UUID,
props search.SelectProperties, additional additional.Properties,
repl *additional.ReplicationProperties, tenant string) (*search.Result, error)
// ObjectsByID fetches objects by id without a class filter; used for
// cross-class lookups. nearParamsVector always passes an empty tenant here.
ObjectsByID(ctx context.Context, id strfmt.UUID, props search.SelectProperties,
additional additional.Properties, tenant string) (search.Results, error)
}
// newNearParamsVector wires a nearParamsVector to the given modules provider
// (may be nil) and object searcher.
func newNearParamsVector(modulesProvider ModulesProvider, search nearParamsSearcher) *nearParamsVector {
	resolver := &nearParamsVector{
		modulesProvider: modulesProvider,
		search:          search,
	}
	return resolver
}
// vectorFromParams resolves the query vector for a search from exactly one of
// nearVector, nearObject, or a single module search param (e.g. nearText).
//
// The params are validated first. A single module param takes precedence,
// then nearVector, then nearObject. An error is returned when validation
// fails, when resolving the vector fails, or when none of the supported
// params is present.
func (v *nearParamsVector) vectorFromParams(ctx context.Context,
	nearVector *searchparams.NearVector, nearObject *searchparams.NearObject,
	moduleParams map[string]interface{}, className string, tenant string,
) ([]float32, error) {
	if err := v.validateNearParams(nearVector, nearObject, moduleParams, className); err != nil {
		return nil, err
	}

	if len(moduleParams) == 1 {
		// The map has exactly one entry, so this body runs once.
		for name, value := range moduleParams {
			return v.vectorFromModules(ctx, className, name, value, tenant)
		}
	}

	if nearVector != nil {
		return nearVector.Vector, nil
	}

	if nearObject != nil {
		vector, err := v.vectorFromNearObjectParams(ctx, className, nearObject, tenant)
		if err != nil {
			// Wrap keeps the message "nearObject params: <cause>" while
			// preserving the underlying error for inspection.
			return nil, errors.Wrap(err, "nearObject params")
		}
		return vector, nil
	}

	// Callers are expected to supply one supported param. This point is
	// still reachable, e.g. with no params at all, or with several module
	// params when no modules provider exists to reject them during
	// validation. Report an error rather than panicking so a malformed
	// request cannot take down the caller.
	return nil, errors.New("vectorFromParams was called without any known params present")
}
func (v *nearParamsVector) validateNearParams(nearVector *searchparams.NearVector,
nearObject *searchparams.NearObject,
moduleParams map[string]interface{}, className ...string,
) error {
if len(moduleParams) == 1 && nearVector != nil && nearObject != nil {
return errors.Errorf("found 'nearText' and 'nearVector' and 'nearObject' parameters " +
"which are conflicting, choose one instead")
}
if len(moduleParams) == 1 && nearVector != nil {
return errors.Errorf("found both 'nearText' and 'nearVector' parameters " +
"which are conflicting, choose one instead")
}
if len(moduleParams) == 1 && nearObject != nil {
return errors.Errorf("found both 'nearText' and 'nearObject' parameters " +
"which are conflicting, choose one instead")
}
if nearVector != nil && nearObject != nil {
return errors.Errorf("found both 'nearVector' and 'nearObject' parameters " +
"which are conflicting, choose one instead")
}
if v.modulesProvider != nil {
if len(moduleParams) > 1 {
params := []string{}
for p := range moduleParams {
params = append(params, fmt.Sprintf("'%s'", p))
}
return errors.Errorf("found more then one module param: %s which are conflicting "+
"choose one instead", strings.Join(params, ", "))
}
for name, value := range moduleParams {
if len(className) == 1 {
err := v.modulesProvider.ValidateSearchParam(name, value, className[0])
if err != nil {
return err
}
} else {
err := v.modulesProvider.CrossClassValidateSearchParam(name, value)
if err != nil {
return err
}
}
}
}
if nearVector != nil {
if nearVector.Certainty != 0 && nearVector.Distance != 0 {
return errors.Errorf("found 'certainty' and 'distance' set in nearVector " +
"which are conflicting, choose one instead")
}
}
if nearObject != nil {
if nearObject.Certainty != 0 && nearObject.Distance != 0 {
return errors.Errorf("found 'certainty' and 'distance' set in nearObject " +
"which are conflicting, choose one instead")
}
}
return nil
}
// vectorFromModules asks the modules provider to vectorize the module search
// param paramName/paramValue for className (within tenant). v.findVector is
// handed to the provider as a callback it may use to look up stored object
// vectors by ID.
//
// It returns an error if no modules provider is configured or if
// vectorization fails.
func (v *nearParamsVector) vectorFromModules(ctx context.Context,
	className, paramName string, paramValue interface{}, tenant string,
) ([]float32, error) {
	if v.modulesProvider == nil {
		return nil, errors.New("no modules defined")
	}
	vector, err := v.modulesProvider.VectorFromSearchParam(ctx,
		className, paramName, paramValue, v.findVector, tenant,
	)
	if err != nil {
		// Wrap rather than format with %v. The message stays
		// "vectorize params: <cause>", but the cause remains reachable via
		// errors.Cause and, with pkg/errors v0.9+, errors.Is/As.
		return nil, errors.Wrap(err, "vectorize params")
	}
	return vector, nil
}
// findVector returns the stored vector of the object with the given id. An
// empty className selects a cross-class lookup, used by Explore searches that
// have no class context; tenant is not used in that case.
func (v *nearParamsVector) findVector(ctx context.Context, className string, id strfmt.UUID, tenant string) ([]float32, error) {
	if className == "" {
		return v.crossClassFindVector(ctx, id)
	}
	return v.classFindVector(ctx, className, id, tenant)
}
// classFindVector fetches the object id of className (within tenant) and
// returns its vector. A missing object yields a "vector not found" error.
func (v *nearParamsVector) classFindVector(ctx context.Context, className string,
	id strfmt.UUID, tenant string,
) ([]float32, error) {
	obj, lookupErr := v.search.Object(ctx, className, id,
		search.SelectProperties{}, additional.Properties{}, nil, tenant)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case obj == nil:
		return nil, errors.New("vector not found")
	default:
		return obj.Vector, nil
	}
}
// crossClassFindVector looks up every object with the given id, without a
// class filter. It returns the single vector when exactly one object matches,
// combines all matching vectors with libvectorizer.CombineVectors when
// several match, and errors when none do.
func (v *nearParamsVector) crossClassFindVector(ctx context.Context, id strfmt.UUID) ([]float32, error) {
	found, err := v.search.ObjectsByID(ctx, id, search.SelectProperties{}, additional.Properties{}, "")
	if err != nil {
		return nil, errors.Wrap(err, "find objects")
	}
	if len(found) == 0 {
		return nil, errors.New("vector not found")
	}
	if len(found) == 1 {
		return found[0].Vector, nil
	}
	// Several objects share this id (presumably in different classes);
	// merge their vectors into one query vector.
	collected := make([][]float32, 0, len(found))
	for i := range found {
		collected = append(collected, found[i].Vector)
	}
	return libvectorizer.CombineVectors(collected), nil
}
func (v *nearParamsVector) crossClassVectorFromNearObjectParams(ctx context.Context,
params *searchparams.NearObject,
) ([]float32, error) {
return v.vectorFromNearObjectParams(ctx, "", params, "")
}
func (v *nearParamsVector) vectorFromNearObjectParams(ctx context.Context,
className string, params *searchparams.NearObject, tenant string,
) ([]float32, error) {
if len(params.ID) == 0 && len(params.Beacon) == 0 {
return nil, errors.New("empty id and beacon")
}
var id strfmt.UUID
targetClassName := className
if len(params.ID) > 0 {
id = strfmt.UUID(params.ID)
} else {
ref, err := crossref.Parse(params.Beacon)
if err != nil {
return nil, err
}
id = ref.TargetID
if ref.Class != "" {
targetClassName = ref.Class
}
}
return v.findVector(ctx, targetClassName, id, tenant)
}
// extractCertaintyFromParams derives the certainty threshold for a search.
//
// nearVector is consulted first, then nearObject. For each, an explicit
// non-zero certainty wins; otherwise a provided distance is converted with
// additional.DistToCertainty. If neither yields a value, a single module
// param is consulted. If nothing applies, 0 is returned.
func (v *nearParamsVector) extractCertaintyFromParams(nearVector *searchparams.NearVector,
	nearObject *searchparams.NearObject, moduleParams map[string]interface{},
) float64 {
	if nearVector != nil {
		switch {
		case nearVector.Certainty != 0:
			return nearVector.Certainty
		case nearVector.WithDistance:
			return additional.DistToCertainty(nearVector.Distance)
		}
	}

	if nearObject != nil {
		switch {
		case nearObject.Certainty != 0:
			return nearObject.Certainty
		case nearObject.WithDistance:
			return additional.DistToCertainty(nearObject.Distance)
		}
	}

	if len(moduleParams) != 1 {
		return 0
	}
	return v.extractCertaintyFromModuleParams(moduleParams)
}
// extractCertaintyFromModuleParams returns the certainty from the first
// module param that implements modulecapabilities.NearParam and reports a
// similarity metric. That is its non-zero certainty, or else its distance
// converted to a certainty. It returns 0 when no such param exists.
func (v *nearParamsVector) extractCertaintyFromModuleParams(moduleParams map[string]interface{}) float64 {
	for _, raw := range moduleParams {
		np, isNearParam := raw.(modulecapabilities.NearParam)
		if !isNearParam || !np.SimilarityMetricProvided() {
			continue
		}
		if c := np.GetCertainty(); c != 0 {
			return c
		}
		return additional.DistToCertainty(np.GetDistance())
	}
	return 0
}