File size: 3,740 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package traverser

import (
	"fmt"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/vectorindex/common"
)

// validateExploreDistance checks that a cross-class Explore query can be
// served. First, every class in the schema must share a single distance
// metric. Second, any certainty usage in params must be supported by that
// shared metric.
func (t *Traverser) validateExploreDistance(params ExploreParams) error {
	// Certainty support depends on the metric, so the shared metric has to be
	// resolved (and agreed upon by all classes) before params can be checked.
	sharedDist, compatErr := t.validateCrossClassDistanceCompatibility()
	if compatErr != nil {
		return compatErr
	}
	return t.validateExploreDistanceParams(params, sharedDist)
}

// validateCrossClassDistanceCompatibility ensures that all classes are
// configured with the same vector index distance metric. A cross-class
// vector search requires this.
//
// On success it returns that shared metric. If the schema has no object
// classes (Objects is nil, or it holds no non-nil classes), it returns
// common.DefaultDistanceMetric. If schema.TypeAssertVectorIndex fails for a
// class, that error is returned unchanged. If more than one metric is in
// use, the returned error names every class and its metric.
func (t *Traverser) validateCrossClassDistanceCompatibility() (distType string, err error) {
	s := t.schemaGetter.GetSchemaSkipAuth()
	if s.Objects == nil {
		return common.DefaultDistanceMetric, nil
	}

	var (
		// set of distinct distance metrics across all classes; more than
		// one entry means a cross-class vector search is not possible
		distancerTypes = make(map[string]struct{})

		// class name -> distance metric, used only to build a descriptive
		// error when the metrics disagree
		classDistanceConfigs = make(map[string]string)
	)

	for _, class := range s.Objects.Classes {
		if class == nil {
			continue
		}

		vectorConfig, assertErr := schema.TypeAssertVectorIndex(class)
		if assertErr != nil {
			return "", assertErr
		}

		dist := vectorConfig.DistanceName()
		distancerTypes[dist] = struct{}{}
		classDistanceConfigs[class.Class] = dist
	}

	if len(distancerTypes) > 1 {
		return "", crossClassDistCompatError(classDistanceConfigs)
	}

	// Zero or one metric remains. With no classes there is nothing that can
	// disagree, so fall back to the default metric (matching the nil-Objects
	// case above). Previously this reported a misleading "different distance
	// metrics" error with an empty class list.
	distType = common.DefaultDistanceMetric
	for dt := range distancerTypes {
		distType = dt
	}
	return distType, nil
}

// validateExploreDistanceParams rejects Explore requests that involve
// certainty when distType is not the cosine metric. A request involves
// certainty if it carries a non-zero certainty value (as extracted from
// params) or asks for the certainty property. Requests that do not involve
// certainty always pass, whatever the metric.
func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error {
	certainty := extractCertaintyFromExploreParams(params)
	usesCertainty := certainty != 0 || params.WithCertaintyProp

	if usesCertainty && distType != common.DistanceCosine {
		return certaintyUnsupportedError(distType)
	}
	return nil
}

// validateGetDistanceParams verifies two things about the class targeted by
// a Get query: it must exist in the schema, and its vector index must use
// the cosine distance metric. Any other metric yields a
// certainty-unsupported error.
func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error {
	currentSchema := t.schemaGetter.GetSchemaSkipAuth()
	targetClass := currentSchema.GetClass(schema.ClassName(params.ClassName))
	if targetClass == nil {
		return fmt.Errorf("failed to find class '%s' in schema", params.ClassName)
	}

	indexConfig, err := schema.TypeAssertVectorIndex(targetClass)
	if err != nil {
		return err
	}

	distName := indexConfig.DistanceName()
	if distName == common.DistanceCosine {
		return nil
	}
	return certaintyUnsupportedError(distName)
}

func crossClassDistCompatError(classDistanceConfigs map[string]string) error {
	errorMsg := "vector search across classes not possible: found different distance metrics:"
	for class, dist := range classDistanceConfigs {
		errorMsg = fmt.Sprintf("%s class '%s' uses distance metric '%s',", errorMsg, class, dist)
	}
	errorMsg = strings.TrimSuffix(errorMsg, ",")

	return fmt.Errorf(errorMsg)
}

// certaintyUnsupportedError reports that a certainty value cannot be computed
// or returned because the vector index uses distType rather than a metric
// that supports certainty.
func certaintyUnsupportedError(distType string) error {
	const format = "can't compute and return certainty when vector index is configured with %s distance"
	return errors.Errorf(format, distType)
}