Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package vectorizer | |
import ( | |
"context" | |
"errors" | |
"reflect" | |
"testing" | |
"github.com/go-openapi/strfmt" | |
"github.com/google/uuid" | |
"github.com/stretchr/testify/assert" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/schema/crossref" | |
"github.com/weaviate/weaviate/entities/search" | |
"github.com/weaviate/weaviate/modules/ref2vec-centroid/config" | |
) | |
func TestVectorizer_New(t *testing.T) { | |
repo := &fakeObjectsRepo{} | |
t.Run("default is set correctly", func(t *testing.T) { | |
vzr := New(fakeClassConfig(config.Default()), repo.Object) | |
expected := reflect.ValueOf(calculateMean).Pointer() | |
received := reflect.ValueOf(vzr.calcFn).Pointer() | |
assert.EqualValues(t, expected, received) | |
}) | |
t.Run("default calcFn is used when none provided", func(t *testing.T) { | |
cfg := fakeClassConfig{"method": ""} | |
vzr := New(cfg, repo.Object) | |
expected := reflect.ValueOf(calculateMean).Pointer() | |
received := reflect.ValueOf(vzr.calcFn).Pointer() | |
assert.EqualValues(t, expected, received) | |
}) | |
} | |
func TestVectorizer_Object(t *testing.T) { | |
t.Run("calculate with mean", func(t *testing.T) { | |
type objectSearchResult struct { | |
res *search.Result | |
err error | |
} | |
tests := []struct { | |
name string | |
objectSearchResults []objectSearchResult | |
expectedResult []float32 | |
expectedCalcError error | |
}{ | |
{ | |
name: "expected success 1", | |
objectSearchResults: []objectSearchResult{ | |
{res: &search.Result{Vector: []float32{2, 4, 6}}}, | |
{res: &search.Result{Vector: []float32{4, 6, 8}}}, | |
}, | |
expectedResult: []float32{3, 5, 7}, | |
}, | |
{ | |
name: "expected success 2", | |
objectSearchResults: []objectSearchResult{ | |
{res: &search.Result{Vector: []float32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}}, | |
{res: &search.Result{Vector: []float32{2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}}, | |
{res: &search.Result{Vector: []float32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3}}}, | |
{res: &search.Result{Vector: []float32{4, 4, 4, 4, 4, 4, 4, 4, 4, 4}}}, | |
{res: &search.Result{Vector: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5}}}, | |
{res: &search.Result{Vector: []float32{6, 6, 6, 6, 6, 6, 6, 6, 6, 6}}}, | |
{res: &search.Result{Vector: []float32{7, 7, 7, 7, 7, 7, 7, 7, 7, 7}}}, | |
{res: &search.Result{Vector: []float32{8, 8, 8, 8, 8, 8, 8, 8, 8, 8}}}, | |
{res: &search.Result{Vector: []float32{9, 9, 9, 9, 9, 9, 9, 9, 9, 9}}}, | |
}, | |
expectedResult: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, | |
}, | |
{ | |
name: "expected success 3", | |
objectSearchResults: []objectSearchResult{{}}, | |
}, | |
{ | |
name: "expected success 4", | |
objectSearchResults: []objectSearchResult{ | |
{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, | |
}, | |
expectedResult: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, | |
}, | |
{ | |
name: "expected success 5", | |
objectSearchResults: []objectSearchResult{ | |
{res: &search.Result{}}, | |
}, | |
expectedResult: nil, | |
}, | |
{ | |
name: "expected error - mismatched vector dimensions", | |
objectSearchResults: []objectSearchResult{ | |
{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, | |
{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8}}}, | |
}, | |
expectedCalcError: errors.New( | |
"calculate vector: calculate mean: found vectors of different length: 9 and 8"), | |
}, | |
} | |
for _, test := range tests { | |
t.Run(test.name, func(t *testing.T) { | |
ctx := context.Background() | |
repo := &fakeObjectsRepo{} | |
refProps := []interface{}{"toRef"} | |
cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} | |
crossRefs := make([]*crossref.Ref, len(test.objectSearchResults)) | |
modelRefs := make(models.MultipleRef, len(test.objectSearchResults)) | |
for i, res := range test.objectSearchResults { | |
crossRef := crossref.New("localhost", "SomeClass", | |
strfmt.UUID(uuid.NewString())) | |
crossRefs[i] = crossRef | |
modelRefs[i] = crossRef.SingleRef() | |
repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, ""). | |
Return(res.res, res.err) | |
} | |
obj := &models.Object{ | |
Properties: map[string]interface{}{"toRef": modelRefs}, | |
} | |
err := New(cfg, repo.Object).Object(ctx, obj) | |
if test.expectedCalcError != nil { | |
assert.EqualError(t, err, test.expectedCalcError.Error()) | |
} else { | |
assert.EqualValues(t, test.expectedResult, obj.Vector) | |
} | |
}) | |
} | |
}) | |
// due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320, | |
// MultipleRef's can appear as empty []interface{} when no actual refs are provided for | |
// an object's reference property. | |
// | |
// this test asserts that reference properties do not break when they are unmarshalled | |
// as empty interface{} slices. | |
t.Run("when rep prop is stored as empty interface{} slice", func(t *testing.T) { | |
ctx := context.Background() | |
repo := &fakeObjectsRepo{} | |
refProps := []interface{}{"toRef"} | |
cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} | |
obj := &models.Object{ | |
Properties: map[string]interface{}{"toRef": []interface{}{}}, | |
} | |
err := New(cfg, repo.Object).Object(ctx, obj) | |
assert.Nil(t, err) | |
assert.Nil(t, obj.Vector) | |
}) | |
} | |
func TestVectorizer_Tenant(t *testing.T) { | |
objectSearchResults := search.Result{Vector: []float32{}} | |
ctx := context.Background() | |
repo := &fakeObjectsRepo{} | |
refProps := []interface{}{"toRef"} | |
cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} | |
tenant := "randomTenant" | |
crossRef := crossref.New("localhost", "SomeClass", | |
strfmt.UUID(uuid.NewString())) | |
modelRefs := models.MultipleRef{crossRef.SingleRef()} | |
repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, tenant). | |
Return(&objectSearchResults, nil) | |
obj := &models.Object{ | |
Properties: map[string]interface{}{"toRef": modelRefs}, | |
Tenant: tenant, | |
} | |
err := New(cfg, repo.Object).Object(ctx, obj) | |
assert.Nil(t, err) | |
assert.Nil(t, obj.Vector) | |
} | |