File size: 3,929 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package vectorizer

import (
	"context"
	"fmt"

	"github.com/go-openapi/strfmt"
	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/moduletools"
	"github.com/weaviate/weaviate/entities/schema/crossref"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/modules/ref2vec-centroid/config"
)

// calcFn is the signature of a vector calculation: it takes any number of
// vectors and returns a single vector together with an error.
type calcFn func(vecs ...[]float32) ([]float32, error)

// Vectorizer sets an object's vector from the vectors of the objects it
// references.
//
// config supplies the reference properties to follow and the calculation
// method, calcFn combines the collected reference vectors into one vector,
// and findObjectFn is used to look up each referenced object.
type Vectorizer struct {
	config       *config.Config
	calcFn       calcFn
	findObjectFn modulecapabilities.FindObjectFn
}

// New creates a Vectorizer for the given class configuration that uses findFn
// to look up referenced objects.
//
// The calculation function is selected from the configured calculation
// method. Every method, config.MethodMean included, currently resolves to
// calculateMean.
func New(cfg moduletools.ClassConfig, findFn modulecapabilities.FindObjectFn) *Vectorizer {
	classCfg := config.New(cfg)

	var calculate calcFn
	switch classCfg.CalculationMethod() {
	case config.MethodMean:
		calculate = calculateMean
	default:
		calculate = calculateMean
	}

	return &Vectorizer{
		config:       classCfg,
		calcFn:       calculate,
		findObjectFn: findFn,
	}
}

// Object computes obj's vector from the vectors of the objects it references
// through the configured reference properties and stores it in obj.Vector.
//
// When none of the referenced objects provides a vector, obj.Vector is reset
// to nil and no error is returned. Errors from the reference lookup are
// returned unchanged; errors from the calculation are wrapped.
func (v *Vectorizer) Object(ctx context.Context, obj *models.Object) error {
	vectors, err := v.referenceVectorSearch(ctx, obj, v.config.ReferenceProperties())
	if err != nil {
		return err
	}

	// nothing to aggregate: make sure no stale vector is left on the object
	if len(vectors) == 0 {
		obj.Vector = nil
		return nil
	}

	result, err := v.calcFn(vectors...)
	if err != nil {
		return fmt.Errorf("calculate vector: %w", err)
	}
	obj.Vector = result

	return nil
}

// referenceVectorSearch collects the vectors of all objects that obj
// references through the properties named in refProps.
//
// Each beacon found on obj is resolved via findReferenceObject using obj's
// tenant. Referenced objects without a vector are skipped. The first lookup
// error aborts the search and is returned as is.
//
// An object without properties has no references, so nil is returned for it.
// An error is returned if obj.Properties holds anything other than a
// map[string]interface{}.
func (v *Vectorizer) referenceVectorSearch(ctx context.Context,
	obj *models.Object, refProps map[string]struct{},
) ([][]float32, error) {
	if obj.Properties == nil {
		return nil, nil
	}

	// guard the assertion: an unchecked one would panic on unexpected input
	props, ok := obj.Properties.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("unexpected properties type %T", obj.Properties)
	}

	var refVecs [][]float32

	// use the ids from parent's beacons to find the referenced objects
	beacons := beaconsForVectorization(props, refProps)
	for _, beacon := range beacons {
		res, err := v.findReferenceObject(ctx, beacon, obj.Tenant)
		if err != nil {
			return nil, err
		}

		// if the ref'd object has a vector, we grab it.
		// these will be used to compute the parent's
		// vector eventually
		if res.Vector != nil {
			refVecs = append(refVecs, res.Vector)
		}
	}

	return refVecs, nil
}

// findReferenceObject resolves beacon to the object it points at. The beacon
// is parsed into a cross-reference, whose class and target ID are passed to
// findObjectFn together with the given tenant.
//
// It returns an error if the beacon cannot be parsed, if the lookup fails, or
// if the lookup yields no object. In all of these cases the returned result
// is nil.
func (v *Vectorizer) findReferenceObject(ctx context.Context, beacon strfmt.URI, tenant string) (*search.Result, error) {
	ref, err := crossref.Parse(beacon.String())
	if err != nil {
		return nil, fmt.Errorf("parse beacon %q: %w", beacon, err)
	}

	res, err := v.findObjectFn(ctx, ref.Class, ref.TargetID,
		search.SelectProperties{}, additional.Properties{}, tenant)
	if err != nil {
		return nil, fmt.Errorf("find object with beacon %q: %w", beacon, err)
	}
	// a nil result without an error means the object does not exist
	if res == nil {
		return nil, fmt.Errorf("find object with beacon %q: not found", beacon)
	}

	return res, nil
}

// beaconsForVectorization collects the beacons of all references that
// allProps holds under the property names listed in targetRefProps.
//
// Only values of type models.MultipleRef contribute beacons; values of any
// other type are ignored, as are nil entries inside a MultipleRef. The order
// of the returned beacons follows Go's map iteration order and is therefore
// unspecified. The result is nil when no beacon is found.
func beaconsForVectorization(allProps map[string]interface{},
	targetRefProps map[string]struct{},
) []strfmt.URI {
	var beacons []strfmt.URI

	// add any refs that were supplied as a part of the parent
	// object, like when caller is AddObject/UpdateObject
	for prop, val := range allProps {
		if _, ok := targetRefProps[prop]; !ok {
			continue
		}

		// due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320,
		// MultipleRef's can appear as empty []interface{} when no actual refs are provided for
		// an object's reference property.
		//
		// such a value, like any other non-MultipleRef value, carries no beacons, so skip it.
		refs, ok := val.(models.MultipleRef)
		if !ok {
			continue
		}

		for _, ref := range refs {
			// a nil entry has no beacon; dereferencing it would panic
			if ref == nil {
				continue
			}
			beacons = append(beacons, ref.Beacon)
		}
	}

	return beacons
}