SemanticSearchPOC / usecases /traverser /traverser_validate_distance_metrics.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.74 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package traverser
import (
	"fmt"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/vectorindex/common"
)
// validateExploreDistance checks that an Explore query is compatible with
// the distance metrics configured in the schema. It makes two checks:
//   - every class must share a single distance metric, because Explore
//     searches across classes;
//   - certainty may only be involved when that shared metric is cosine.
func (t *Traverser) validateExploreDistance(params ExploreParams) error {
	sharedMetric, compatErr := t.validateCrossClassDistanceCompatibility()
	if compatErr != nil {
		return compatErr
	}

	return t.validateExploreDistanceParams(params, sharedMetric)
}
// validateCrossClassDistanceCompatibility ensures that all classes in the
// schema are configured with the same vector index distance metric. A vector
// search spanning multiple classes requires this.
//
// The return value depends on the schema:
//   - every class uses the same metric: that metric is returned;
//   - the schema has no object classes at all: common.DefaultDistanceMetric
//     is returned, because there is nothing that could be incompatible;
//   - more than one metric is in use: the returned error lists the metric
//     configured for each class.
//
// Any error from schema.TypeAssertVectorIndex is returned unchanged.
func (t *Traverser) validateCrossClassDistanceCompatibility() (distType string, err error) {
	s := t.schemaGetter.GetSchemaSkipAuth()
	if s.Objects == nil {
		return common.DefaultDistanceMetric, nil
	}

	// distancerTypes is the set of distinct distance metrics in use across
	// all classes. More than one entry means a cross-class vector search is
	// not possible.
	distancerTypes := make(map[string]struct{})
	// classDistanceConfigs maps class name to distance metric. It is only
	// used to build a descriptive error when the metrics differ.
	classDistanceConfigs := make(map[string]string, len(s.Objects.Classes))

	for _, class := range s.Objects.Classes {
		if class == nil {
			continue
		}

		vectorConfig, assertErr := schema.TypeAssertVectorIndex(class)
		if assertErr != nil {
			return "", assertErr
		}

		dist := vectorConfig.DistanceName()
		distancerTypes[dist] = struct{}{}
		classDistanceConfigs[class.Class] = dist
	}

	switch len(distancerTypes) {
	case 0:
		// No (non-nil) classes exist. Treat this like a nil Objects set rather
		// than reporting "different distance metrics" with none listed.
		return common.DefaultDistanceMetric, nil
	case 1:
		// Exactly one entry, so the first (and only) key is the shared metric.
		for dt := range distancerTypes {
			return dt, nil
		}
	}

	return "", crossClassDistCompatError(classDistanceConfigs)
}
// validateExploreDistanceParams rejects Explore requests that involve
// certainty when the schema-wide distance metric distType is not cosine.
//
// Certainty counts as involved in either of these cases:
//   - the query carries a non-zero certainty value, as reported by
//     extractCertaintyFromExploreParams;
//   - the query asks for the certainty property (params.WithCertaintyProp).
func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error {
	// extractCertaintyFromExploreParams is evaluated first and unconditionally,
	// matching the original call order.
	usesCertainty := extractCertaintyFromExploreParams(params) != 0 || params.WithCertaintyProp

	if usesCertainty && distType != common.DistanceCosine {
		return certaintyUnsupportedError(distType)
	}

	return nil
}
// validateGetDistanceParams looks up params.ClassName in the schema. It
// returns an error if the class does not exist, if its vector index config
// cannot be asserted, or if that index is configured with any distance
// metric other than cosine.
//
// NOTE(review): unlike validateExploreDistanceParams, this function does not
// itself check whether certainty was requested. Presumably callers only
// invoke it when certainty is set; confirm at the call sites.
func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error {
	currentSchema := t.schemaGetter.GetSchemaSkipAuth()

	targetClass := currentSchema.GetClass(schema.ClassName(params.ClassName))
	if targetClass == nil {
		return fmt.Errorf("failed to find class '%s' in schema", params.ClassName)
	}

	indexConfig, assertErr := schema.TypeAssertVectorIndex(targetClass)
	if assertErr != nil {
		return assertErr
	}

	metric := indexConfig.DistanceName()
	if metric == common.DistanceCosine {
		return nil
	}

	return certaintyUnsupportedError(metric)
}
// crossClassDistCompatError builds the error returned when classes use
// different vector index distance metrics.
//
// classDistanceConfigs maps class name to distance metric. Classes are listed
// in sorted name order, so the message is deterministic. An empty map yields
// just the message prefix.
func crossClassDistCompatError(classDistanceConfigs map[string]string) error {
	classes := make([]string, 0, len(classDistanceConfigs))
	for class := range classDistanceConfigs {
		classes = append(classes, class)
	}
	// Map iteration order is randomized; sort for a stable message.
	sort.Strings(classes)

	// Each entry starts with a space and entries are joined with ",".
	// This reproduces the original "metrics: class 'A' ..., class 'B' ..."
	// layout, and leaves no trailing space when the map is empty.
	entries := make([]string, len(classes))
	for i, class := range classes {
		entries[i] = fmt.Sprintf(" class '%s' uses distance metric '%s'", class, classDistanceConfigs[class])
	}

	// Use a constant format string so that '%' in class or metric names is
	// never interpreted as a formatting verb.
	return fmt.Errorf("vector search across classes not possible: found different distance metrics:%s",
		strings.Join(entries, ","))
}
// certaintyUnsupportedError reports that certainty cannot be computed for a
// vector index configured with the given distance metric (distType).
//
// The error is created with github.com/pkg/errors.Errorf, which also records
// the call stack.
func certaintyUnsupportedError(distType string) error {
	const format = "can't compute and return certainty when vector index is configured with %s distance"
	return errors.Errorf(format, distType)
}