KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.35 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package modcentroid
import (
"context"
"fmt"
"testing"
"time"
"github.com/go-openapi/strfmt"
"github.com/google/uuid"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/schema/crossref"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/usecases/config"
)
func TestRef2VecCentroid(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
sp := newFakeStorageProvider(t)
logger, _ := test.NewNullLogger()
params := moduletools.NewInitParams(sp, nil, config.Config{}, logger)
mod := New()
classConfig := fakeClassConfig(mod.ClassConfigDefaults())
refProp := "someRef"
classConfig["referenceProperties"] = []interface{}{refProp}
t.Run("Init", func(t *testing.T) {
err := mod.Init(ctx, params)
assert.Nil(t, err)
})
t.Run("RootHandler", func(t *testing.T) {
h := mod.RootHandler()
assert.Nil(t, h)
})
t.Run("Type", func(t *testing.T) {
typ := mod.Type()
assert.Equal(t, modulecapabilities.Ref2Vec, typ)
})
t.Run("Name", func(t *testing.T) {
name := mod.Name()
assert.Equal(t, Name, name)
})
t.Run("MetaInfo", func(t *testing.T) {
meta, err := mod.MetaInfo()
assert.Nil(t, err)
assert.Empty(t, meta)
})
t.Run("PropertyConfigDefaults", func(t *testing.T) {
dt := schema.DataType("dataType")
props := mod.PropertyConfigDefaults(&dt)
assert.Nil(t, props)
})
t.Run("ValidateClass", func(t *testing.T) {
t.Run("expected success", func(t *testing.T) {
class := &models.Class{}
err := mod.ValidateClass(ctx, class, classConfig)
assert.Nil(t, err)
})
t.Run("expected error", func(t *testing.T) {
class := &models.Class{Class: "InvalidConfigClass"}
cfg := fakeClassConfig{}
expectedErr := fmt.Sprintf(
"validate %q: invalid config: must have at least one "+
"value in the \"referenceProperties\" field",
class.Class)
err := mod.ValidateClass(ctx, class, cfg)
assert.EqualError(t, err, expectedErr)
})
})
t.Run("VectorizeObject", func(t *testing.T) {
t.Run("expected success", func(t *testing.T) {
t.Run("one refVec", func(t *testing.T) {
repo := &fakeObjectsRepo{}
ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
obj := &models.Object{Properties: map[string]interface{}{
refProp: models.MultipleRef{ref.SingleRef()},
}}
repo.On("Object", ctx, ref.Class, ref.TargetID).
Return(&search.Result{Vector: []float32{1, 2, 3}}, nil)
err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
assert.Nil(t, err)
expectedVec := models.C11yVector{1, 2, 3}
assert.EqualValues(t, expectedVec, obj.Vector)
})
t.Run("no refVecs", func(t *testing.T) {
repo := &fakeObjectsRepo{}
ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
obj := &models.Object{Properties: map[string]interface{}{
refProp: models.MultipleRef{ref.SingleRef()},
}}
repo.On("Object", ctx, ref.Class, ref.TargetID).
Return(&search.Result{}, nil)
err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
assert.Nil(t, err)
assert.Nil(t, nil, obj.Vector)
})
})
t.Run("expected error", func(t *testing.T) {
t.Run("mismatched refVec lengths", func(t *testing.T) {
repo := &fakeObjectsRepo{}
ref1 := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
ref2 := crossref.New("localhost", "OtherClass", strfmt.UUID(uuid.NewString()))
obj := &models.Object{Properties: map[string]interface{}{
refProp: models.MultipleRef{
ref1.SingleRef(),
ref2.SingleRef(),
},
}}
expectedErr := fmt.Errorf("calculate vector: calculate mean: " +
"found vectors of different length: 2 and 3")
repo.On("Object", ctx, ref1.Class, ref1.TargetID).
Return(&search.Result{Vector: []float32{1, 2}}, nil)
repo.On("Object", ctx, ref2.Class, ref2.TargetID).
Return(&search.Result{Vector: []float32{1, 2, 3}}, nil)
err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
assert.EqualError(t, err, expectedErr.Error())
})
})
})
}
type fakeObjectsRepo struct {
mock.Mock
}
func (r *fakeObjectsRepo) Object(ctx context.Context, class string,
id strfmt.UUID, props search.SelectProperties,
addl additional.Properties, tenant string,
) (*search.Result, error) {
args := r.Called(ctx, class, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*search.Result), args.Error(1)
}