Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package traverser | |
import ( | |
"context" | |
"fmt" | |
"strings" | |
"github.com/go-openapi/strfmt" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
"github.com/weaviate/weaviate/entities/schema/crossref" | |
"github.com/weaviate/weaviate/entities/search" | |
"github.com/weaviate/weaviate/entities/searchparams" | |
libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" | |
) | |
type nearParamsVector struct { | |
modulesProvider ModulesProvider | |
search nearParamsSearcher | |
} | |
type nearParamsSearcher interface { | |
Object(ctx context.Context, className string, id strfmt.UUID, | |
props search.SelectProperties, additional additional.Properties, | |
repl *additional.ReplicationProperties, tenant string) (*search.Result, error) | |
ObjectsByID(ctx context.Context, id strfmt.UUID, props search.SelectProperties, | |
additional additional.Properties, tenant string) (search.Results, error) | |
} | |
func newNearParamsVector(modulesProvider ModulesProvider, search nearParamsSearcher) *nearParamsVector { | |
return &nearParamsVector{modulesProvider, search} | |
} | |
func (v *nearParamsVector) vectorFromParams(ctx context.Context, | |
nearVector *searchparams.NearVector, nearObject *searchparams.NearObject, | |
moduleParams map[string]interface{}, className string, tenant string, | |
) ([]float32, error) { | |
err := v.validateNearParams(nearVector, nearObject, moduleParams, className) | |
if err != nil { | |
return nil, err | |
} | |
if len(moduleParams) == 1 { | |
for name, value := range moduleParams { | |
return v.vectorFromModules(ctx, className, name, value, tenant) | |
} | |
} | |
if nearVector != nil { | |
return nearVector.Vector, nil | |
} | |
if nearObject != nil { | |
vector, err := v.vectorFromNearObjectParams(ctx, className, nearObject, tenant) | |
if err != nil { | |
return nil, errors.Errorf("nearObject params: %v", err) | |
} | |
return vector, nil | |
} | |
// either nearObject or nearVector or module search param has to be set, | |
// so if we land here, something has gone very wrong | |
panic("vectorFromParams was called without any known params present") | |
} | |
func (v *nearParamsVector) validateNearParams(nearVector *searchparams.NearVector, | |
nearObject *searchparams.NearObject, | |
moduleParams map[string]interface{}, className ...string, | |
) error { | |
if len(moduleParams) == 1 && nearVector != nil && nearObject != nil { | |
return errors.Errorf("found 'nearText' and 'nearVector' and 'nearObject' parameters " + | |
"which are conflicting, choose one instead") | |
} | |
if len(moduleParams) == 1 && nearVector != nil { | |
return errors.Errorf("found both 'nearText' and 'nearVector' parameters " + | |
"which are conflicting, choose one instead") | |
} | |
if len(moduleParams) == 1 && nearObject != nil { | |
return errors.Errorf("found both 'nearText' and 'nearObject' parameters " + | |
"which are conflicting, choose one instead") | |
} | |
if nearVector != nil && nearObject != nil { | |
return errors.Errorf("found both 'nearVector' and 'nearObject' parameters " + | |
"which are conflicting, choose one instead") | |
} | |
if v.modulesProvider != nil { | |
if len(moduleParams) > 1 { | |
params := []string{} | |
for p := range moduleParams { | |
params = append(params, fmt.Sprintf("'%s'", p)) | |
} | |
return errors.Errorf("found more then one module param: %s which are conflicting "+ | |
"choose one instead", strings.Join(params, ", ")) | |
} | |
for name, value := range moduleParams { | |
if len(className) == 1 { | |
err := v.modulesProvider.ValidateSearchParam(name, value, className[0]) | |
if err != nil { | |
return err | |
} | |
} else { | |
err := v.modulesProvider.CrossClassValidateSearchParam(name, value) | |
if err != nil { | |
return err | |
} | |
} | |
} | |
} | |
if nearVector != nil { | |
if nearVector.Certainty != 0 && nearVector.Distance != 0 { | |
return errors.Errorf("found 'certainty' and 'distance' set in nearVector " + | |
"which are conflicting, choose one instead") | |
} | |
} | |
if nearObject != nil { | |
if nearObject.Certainty != 0 && nearObject.Distance != 0 { | |
return errors.Errorf("found 'certainty' and 'distance' set in nearObject " + | |
"which are conflicting, choose one instead") | |
} | |
} | |
return nil | |
} | |
func (v *nearParamsVector) vectorFromModules(ctx context.Context, | |
className, paramName string, paramValue interface{}, tenant string, | |
) ([]float32, error) { | |
if v.modulesProvider != nil { | |
vector, err := v.modulesProvider.VectorFromSearchParam(ctx, | |
className, paramName, paramValue, v.findVector, tenant, | |
) | |
if err != nil { | |
return nil, errors.Errorf("vectorize params: %v", err) | |
} | |
return vector, nil | |
} | |
return nil, errors.New("no modules defined") | |
} | |
func (v *nearParamsVector) findVector(ctx context.Context, className string, id strfmt.UUID, tenant string) ([]float32, error) { | |
switch className { | |
case "": | |
// Explore cross class searches where we don't have class context | |
return v.crossClassFindVector(ctx, id) | |
default: | |
return v.classFindVector(ctx, className, id, tenant) | |
} | |
} | |
func (v *nearParamsVector) classFindVector(ctx context.Context, className string, | |
id strfmt.UUID, tenant string, | |
) ([]float32, error) { | |
res, err := v.search.Object(ctx, className, id, search.SelectProperties{}, additional.Properties{}, nil, tenant) | |
if err != nil { | |
return nil, err | |
} | |
if res == nil { | |
return nil, errors.New("vector not found") | |
} | |
return res.Vector, nil | |
} | |
func (v *nearParamsVector) crossClassFindVector(ctx context.Context, id strfmt.UUID) ([]float32, error) { | |
res, err := v.search.ObjectsByID(ctx, id, search.SelectProperties{}, additional.Properties{}, "") | |
if err != nil { | |
return nil, errors.Wrap(err, "find objects") | |
} | |
switch len(res) { | |
case 0: | |
return nil, errors.New("vector not found") | |
case 1: | |
return res[0].Vector, nil | |
default: | |
vectors := make([][]float32, len(res)) | |
for i := range res { | |
vectors[i] = res[i].Vector | |
} | |
return libvectorizer.CombineVectors(vectors), nil | |
} | |
} | |
func (v *nearParamsVector) crossClassVectorFromNearObjectParams(ctx context.Context, | |
params *searchparams.NearObject, | |
) ([]float32, error) { | |
return v.vectorFromNearObjectParams(ctx, "", params, "") | |
} | |
func (v *nearParamsVector) vectorFromNearObjectParams(ctx context.Context, | |
className string, params *searchparams.NearObject, tenant string, | |
) ([]float32, error) { | |
if len(params.ID) == 0 && len(params.Beacon) == 0 { | |
return nil, errors.New("empty id and beacon") | |
} | |
var id strfmt.UUID | |
targetClassName := className | |
if len(params.ID) > 0 { | |
id = strfmt.UUID(params.ID) | |
} else { | |
ref, err := crossref.Parse(params.Beacon) | |
if err != nil { | |
return nil, err | |
} | |
id = ref.TargetID | |
if ref.Class != "" { | |
targetClassName = ref.Class | |
} | |
} | |
return v.findVector(ctx, targetClassName, id, tenant) | |
} | |
func (v *nearParamsVector) extractCertaintyFromParams(nearVector *searchparams.NearVector, | |
nearObject *searchparams.NearObject, moduleParams map[string]interface{}, | |
) float64 { | |
if nearVector != nil { | |
if nearVector.Certainty != 0 { | |
return nearVector.Certainty | |
} else if nearVector.WithDistance { | |
return additional.DistToCertainty(nearVector.Distance) | |
} | |
} | |
if nearObject != nil { | |
if nearObject.Certainty != 0 { | |
return nearObject.Certainty | |
} else if nearObject.WithDistance { | |
return additional.DistToCertainty(nearObject.Distance) | |
} | |
} | |
if len(moduleParams) == 1 { | |
return v.extractCertaintyFromModuleParams(moduleParams) | |
} | |
return 0 | |
} | |
func (v *nearParamsVector) extractCertaintyFromModuleParams(moduleParams map[string]interface{}) float64 { | |
for _, param := range moduleParams { | |
if nearParam, ok := param.(modulecapabilities.NearParam); ok { | |
if nearParam.SimilarityMetricProvided() { | |
if certainty := nearParam.GetCertainty(); certainty != 0 { | |
return certainty | |
} else { | |
return additional.DistToCertainty(nearParam.GetDistance()) | |
} | |
} | |
} | |
} | |
return 0 | |
} | |