KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.93 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package vectorizer
import (
"context"
"fmt"
"github.com/go-openapi/strfmt"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/schema/crossref"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/modules/ref2vec-centroid/config"
)
type calcFn func(vecs ...[]float32) ([]float32, error)
type Vectorizer struct {
config *config.Config
calcFn calcFn
findObjectFn modulecapabilities.FindObjectFn
}
func New(cfg moduletools.ClassConfig, findFn modulecapabilities.FindObjectFn) *Vectorizer {
v := &Vectorizer{
config: config.New(cfg),
findObjectFn: findFn,
}
switch v.config.CalculationMethod() {
case config.MethodMean:
v.calcFn = calculateMean
default:
v.calcFn = calculateMean
}
return v
}
func (v *Vectorizer) Object(ctx context.Context, obj *models.Object) error {
props := v.config.ReferenceProperties()
refVecs, err := v.referenceVectorSearch(ctx, obj, props)
if err != nil {
return err
}
if len(refVecs) == 0 {
obj.Vector = nil
return nil
}
vec, err := v.calcFn(refVecs...)
if err != nil {
return fmt.Errorf("calculate vector: %w", err)
}
obj.Vector = vec
return nil
}
func (v *Vectorizer) referenceVectorSearch(ctx context.Context,
obj *models.Object, refProps map[string]struct{},
) ([][]float32, error) {
var refVecs [][]float32
props := obj.Properties.(map[string]interface{})
// use the ids from parent's beacons to find the referenced objects
beacons := beaconsForVectorization(props, refProps)
for _, beacon := range beacons {
res, err := v.findReferenceObject(ctx, beacon, obj.Tenant)
if err != nil {
return nil, err
}
// if the ref'd object has a vector, we grab it.
// these will be used to compute the parent's
// vector eventually
if res.Vector != nil {
refVecs = append(refVecs, res.Vector)
}
}
return refVecs, nil
}
func (v *Vectorizer) findReferenceObject(ctx context.Context, beacon strfmt.URI, tenant string) (res *search.Result, err error) {
ref, err := crossref.Parse(beacon.String())
if err != nil {
return nil, fmt.Errorf("parse beacon %q: %w", beacon, err)
}
res, err = v.findObjectFn(ctx, ref.Class, ref.TargetID,
search.SelectProperties{}, additional.Properties{}, tenant)
if err != nil || res == nil {
if err == nil {
err = fmt.Errorf("not found")
}
err = fmt.Errorf("find object with beacon %q': %w", beacon, err)
}
return
}
func beaconsForVectorization(allProps map[string]interface{},
targetRefProps map[string]struct{},
) []strfmt.URI {
var beacons []strfmt.URI
// add any refs that were supplied as a part of the parent
// object, like when caller is AddObject/UpdateObject
for prop, val := range allProps {
if _, ok := targetRefProps[prop]; ok {
switch refs := val.(type) {
case []interface{}:
// due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320,
// MultipleRef's can appear as empty []interface{} when no actual refs are provided for
// an object's reference property.
//
// if we encounter []interface{}, assume it indicates an empty ref prop, and skip it.
continue
case models.MultipleRef:
for _, ref := range refs {
beacons = append(beacons, ref.Beacon)
}
}
}
}
return beacons
}