Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package traverser | |
import (
	"fmt"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/vectorindex/common"
)
func (t *Traverser) validateExploreDistance(params ExploreParams) error { | |
distType, err := t.validateCrossClassDistanceCompatibility() | |
if err != nil { | |
return err | |
} | |
return t.validateExploreDistanceParams(params, distType) | |
} | |
// ensures that all classes are configured with the same distance type. | |
// if all classes are configured with the same type, said type is returned. | |
// otherwise an error indicating which classes are configured differently. | |
func (t *Traverser) validateCrossClassDistanceCompatibility() (distType string, err error) { | |
s := t.schemaGetter.GetSchemaSkipAuth() | |
if s.Objects == nil { | |
return common.DefaultDistanceMetric, nil | |
} | |
var ( | |
// a set used to determine the discrete number | |
// of vector index distance types used across | |
// all classes. if more than one type exists, | |
// a cross-class vector search is not possible | |
distancerTypes = make(map[string]struct{}) | |
// a mapping of class name to vector index distance | |
// type. used to emit an error if more than one | |
// distance type is found | |
classDistanceConfigs = make(map[string]string) | |
) | |
for _, class := range s.Objects.Classes { | |
if class == nil { | |
continue | |
} | |
vectorConfig, assertErr := schema.TypeAssertVectorIndex(class) | |
if assertErr != nil { | |
err = assertErr | |
return | |
} | |
distancerTypes[vectorConfig.DistanceName()] = struct{}{} | |
classDistanceConfigs[class.Class] = vectorConfig.DistanceName() | |
} | |
if len(distancerTypes) != 1 { | |
err = crossClassDistCompatError(classDistanceConfigs) | |
return | |
} | |
// the above check ensures that the | |
// map only contains one entry | |
for dt := range distancerTypes { | |
distType = dt | |
} | |
return | |
} | |
func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error { | |
certainty := extractCertaintyFromExploreParams(params) | |
if certainty == 0 && !params.WithCertaintyProp { | |
return nil | |
} | |
if distType != common.DistanceCosine { | |
return certaintyUnsupportedError(distType) | |
} | |
return nil | |
} | |
func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error { | |
sch := t.schemaGetter.GetSchemaSkipAuth() | |
class := sch.GetClass(schema.ClassName(params.ClassName)) | |
if class == nil { | |
return fmt.Errorf("failed to find class '%s' in schema", params.ClassName) | |
} | |
vectorConfig, err := schema.TypeAssertVectorIndex(class) | |
if err != nil { | |
return err | |
} | |
if dn := vectorConfig.DistanceName(); dn != common.DistanceCosine { | |
return certaintyUnsupportedError(dn) | |
} | |
return nil | |
} | |
// crossClassDistCompatError builds the error reported when classes are
// configured with differing distance metrics, which makes a vector search
// across classes impossible.
//
// classDistanceConfigs maps class name to the distance metric of that class.
// Classes are listed in ascending name order so the message is deterministic
// (map iteration order is not). Class and metric names are inserted verbatim:
// the assembled text is never used as a format string, so names containing
// '%' are reported unchanged.
func crossClassDistCompatError(classDistanceConfigs map[string]string) error {
	classes := make([]string, 0, len(classDistanceConfigs))
	for class := range classDistanceConfigs {
		classes = append(classes, class)
	}
	sort.Strings(classes)

	var msg strings.Builder
	msg.WriteString("vector search across classes not possible: found different distance metrics:")
	for i, class := range classes {
		// entries are comma separated, without a trailing comma
		if i > 0 {
			msg.WriteByte(',')
		}
		fmt.Fprintf(&msg, " class '%s' uses distance metric '%s'", class, classDistanceConfigs[class])
	}

	return fmt.Errorf("%s", msg.String())
}
func certaintyUnsupportedError(distType string) error { | |
return errors.Errorf( | |
"can't compute and return certainty when vector index is configured with %s distance", | |
distType) | |
} | |