Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: hello@weaviate.io | |
| // | |
| package vectorizer | |
| import ( | |
| "context" | |
| "fmt" | |
| "github.com/go-openapi/strfmt" | |
| "github.com/weaviate/weaviate/entities/additional" | |
| "github.com/weaviate/weaviate/entities/models" | |
| "github.com/weaviate/weaviate/entities/modulecapabilities" | |
| "github.com/weaviate/weaviate/entities/moduletools" | |
| "github.com/weaviate/weaviate/entities/schema/crossref" | |
| "github.com/weaviate/weaviate/entities/search" | |
| "github.com/weaviate/weaviate/modules/ref2vec-centroid/config" | |
| ) | |
| type calcFn func(vecs ...[]float32) ([]float32, error) | |
| type Vectorizer struct { | |
| config *config.Config | |
| calcFn calcFn | |
| findObjectFn modulecapabilities.FindObjectFn | |
| } | |
| func New(cfg moduletools.ClassConfig, findFn modulecapabilities.FindObjectFn) *Vectorizer { | |
| v := &Vectorizer{ | |
| config: config.New(cfg), | |
| findObjectFn: findFn, | |
| } | |
| switch v.config.CalculationMethod() { | |
| case config.MethodMean: | |
| v.calcFn = calculateMean | |
| default: | |
| v.calcFn = calculateMean | |
| } | |
| return v | |
| } | |
| func (v *Vectorizer) Object(ctx context.Context, obj *models.Object) error { | |
| props := v.config.ReferenceProperties() | |
| refVecs, err := v.referenceVectorSearch(ctx, obj, props) | |
| if err != nil { | |
| return err | |
| } | |
| if len(refVecs) == 0 { | |
| obj.Vector = nil | |
| return nil | |
| } | |
| vec, err := v.calcFn(refVecs...) | |
| if err != nil { | |
| return fmt.Errorf("calculate vector: %w", err) | |
| } | |
| obj.Vector = vec | |
| return nil | |
| } | |
| func (v *Vectorizer) referenceVectorSearch(ctx context.Context, | |
| obj *models.Object, refProps map[string]struct{}, | |
| ) ([][]float32, error) { | |
| var refVecs [][]float32 | |
| props := obj.Properties.(map[string]interface{}) | |
| // use the ids from parent's beacons to find the referenced objects | |
| beacons := beaconsForVectorization(props, refProps) | |
| for _, beacon := range beacons { | |
| res, err := v.findReferenceObject(ctx, beacon, obj.Tenant) | |
| if err != nil { | |
| return nil, err | |
| } | |
| // if the ref'd object has a vector, we grab it. | |
| // these will be used to compute the parent's | |
| // vector eventually | |
| if res.Vector != nil { | |
| refVecs = append(refVecs, res.Vector) | |
| } | |
| } | |
| return refVecs, nil | |
| } | |
| func (v *Vectorizer) findReferenceObject(ctx context.Context, beacon strfmt.URI, tenant string) (res *search.Result, err error) { | |
| ref, err := crossref.Parse(beacon.String()) | |
| if err != nil { | |
| return nil, fmt.Errorf("parse beacon %q: %w", beacon, err) | |
| } | |
| res, err = v.findObjectFn(ctx, ref.Class, ref.TargetID, | |
| search.SelectProperties{}, additional.Properties{}, tenant) | |
| if err != nil || res == nil { | |
| if err == nil { | |
| err = fmt.Errorf("not found") | |
| } | |
| err = fmt.Errorf("find object with beacon %q': %w", beacon, err) | |
| } | |
| return | |
| } | |
| func beaconsForVectorization(allProps map[string]interface{}, | |
| targetRefProps map[string]struct{}, | |
| ) []strfmt.URI { | |
| var beacons []strfmt.URI | |
| // add any refs that were supplied as a part of the parent | |
| // object, like when caller is AddObject/UpdateObject | |
| for prop, val := range allProps { | |
| if _, ok := targetRefProps[prop]; ok { | |
| switch refs := val.(type) { | |
| case []interface{}: | |
| // due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320, | |
| // MultipleRef's can appear as empty []interface{} when no actual refs are provided for | |
| // an object's reference property. | |
| // | |
| // if we encounter []interface{}, assume it indicates an empty ref prop, and skip it. | |
| continue | |
| case models.MultipleRef: | |
| for _, ref := range refs { | |
| beacons = append(beacons, ref.Beacon) | |
| } | |
| } | |
| } | |
| } | |
| return beacons | |
| } | |