Spaces:
Running
Running
File size: 3,740 Bytes
b110593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package traverser
import (
	"fmt"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/vectorindex/common"
)
// validateExploreDistance checks the distance-related parameters of an Explore
// query in two steps. It first verifies that every class in the schema shares
// a single distance metric. It then validates params against that metric.
// The first error encountered is returned; nil means the query may proceed.
func (t *Traverser) validateExploreDistance(params ExploreParams) error {
	distType, err := t.validateCrossClassDistanceCompatibility()
	if err == nil {
		// only reachable once a single shared metric has been established
		err = t.validateExploreDistanceParams(params, distType)
	}
	return err
}
// validateCrossClassDistanceCompatibility ensures that all classes in the
// schema are configured with the same vector index distance metric, and
// returns that shared metric name.
//
// If the schema contains no classes (Objects is nil, or every class entry is
// nil), there is nothing that could conflict, and common.DefaultDistanceMetric
// is returned. If the classes use more than one metric, the returned error
// lists which class uses which metric. An error from
// schema.TypeAssertVectorIndex is returned unchanged.
func (t *Traverser) validateCrossClassDistanceCompatibility() (distType string, err error) {
	s := t.schemaGetter.GetSchemaSkipAuth()
	if s.Objects == nil {
		return common.DefaultDistanceMetric, nil
	}

	var (
		// the set of distinct vector index distance types used across all
		// classes. if more than one type exists, a cross-class vector search
		// is not possible
		distancerTypes = make(map[string]struct{})
		// class name -> vector index distance type, used to emit a
		// descriptive error if more than one distance type is found
		classDistanceConfigs = make(map[string]string, len(s.Objects.Classes))
	)

	for _, class := range s.Objects.Classes {
		if class == nil {
			continue
		}
		vectorConfig, assertErr := schema.TypeAssertVectorIndex(class)
		if assertErr != nil {
			return "", assertErr
		}
		dn := vectorConfig.DistanceName()
		distancerTypes[dn] = struct{}{}
		classDistanceConfigs[class.Class] = dn
	}

	if len(distancerTypes) > 1 {
		return "", crossClassDistCompatError(classDistanceConfigs)
	}

	// at this point the set holds at most one entry, so this loop
	// returns on its first (and only) iteration if there is one
	for dt := range distancerTypes {
		return dt, nil
	}

	// no non-nil classes: treat like the s.Objects == nil case rather than
	// reporting "different distance metrics" with an empty class list
	return common.DefaultDistanceMetric, nil
}
// validateExploreDistanceParams checks params against distType, the distance
// metric shared by all classes. A query is treated as requesting certainty if
// it carries a non-zero certainty value or sets WithCertaintyProp. Such a
// query fails with certaintyUnsupportedError unless distType is cosine. A
// query that requests no certainty is accepted for any metric.
func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error {
	certainty := extractCertaintyFromExploreParams(params)
	wantsCertainty := certainty != 0 || params.WithCertaintyProp

	if wantsCertainty && distType != common.DistanceCosine {
		return certaintyUnsupportedError(distType)
	}
	return nil
}
// validateGetDistanceParams looks up the class targeted by a Get query and
// requires its vector index to use the cosine distance. It returns an error if
// the class is not in the schema, or if its vector index config cannot be
// asserted. It returns certaintyUnsupportedError if the metric is anything
// other than cosine.
func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error {
	target := schema.ClassName(params.ClassName)
	class := t.schemaGetter.GetSchemaSkipAuth().GetClass(target)
	if class == nil {
		return fmt.Errorf("failed to find class '%s' in schema", params.ClassName)
	}

	vectorConfig, err := schema.TypeAssertVectorIndex(class)
	if err != nil {
		return err
	}

	distName := vectorConfig.DistanceName()
	if distName == common.DistanceCosine {
		return nil
	}
	return certaintyUnsupportedError(distName)
}
func crossClassDistCompatError(classDistanceConfigs map[string]string) error {
errorMsg := "vector search across classes not possible: found different distance metrics:"
for class, dist := range classDistanceConfigs {
errorMsg = fmt.Sprintf("%s class '%s' uses distance metric '%s',", errorMsg, class, dist)
}
errorMsg = strings.TrimSuffix(errorMsg, ",")
return fmt.Errorf(errorMsg)
}
// certaintyUnsupportedError reports that a certainty value was requested while
// the vector index uses distType. Callers in this file return it whenever the
// metric is not cosine.
func certaintyUnsupportedError(distType string) error {
	const format = "can't compute and return certainty when vector index is configured with %s distance"
	return errors.Errorf(format, distType)
}
|