Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package objects | |
import ( | |
"context" | |
"errors" | |
"testing" | |
"time" | |
"github.com/go-openapi/strfmt" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/mock" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/schema" | |
"github.com/weaviate/weaviate/entities/search" | |
"github.com/weaviate/weaviate/usecases/config" | |
enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw" | |
) | |
func Test_UpdateAction(t *testing.T) { | |
var ( | |
db *fakeVectorRepo | |
modulesProvider *fakeModulesProvider | |
manager *Manager | |
extender *fakeExtender | |
projectorFake *fakeProjector | |
) | |
schema := schema.Schema{ | |
Objects: &models.Schema{ | |
Classes: []*models.Class{ | |
{ | |
Class: "ActionClass", | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
Properties: []*models.Property{ | |
{ | |
DataType: schema.DataTypeText.PropString(), | |
Tokenization: models.PropertyTokenizationWhitespace, | |
Name: "foo", | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
reset := func() { | |
db = &fakeVectorRepo{} | |
schemaManager := &fakeSchemaManager{ | |
GetSchemaResponse: schema, | |
} | |
locks := &fakeLocks{} | |
cfg := &config.WeaviateConfig{} | |
cfg.Config.QueryDefaults.Limit = 20 | |
cfg.Config.QueryMaximumResults = 200 | |
authorizer := &fakeAuthorizer{} | |
logger, _ := test.NewNullLogger() | |
extender = &fakeExtender{} | |
projectorFake = &fakeProjector{} | |
metrics := &fakeMetrics{} | |
modulesProvider = getFakeModulesProviderWithCustomExtenders(extender, projectorFake) | |
manager = NewManager(locks, schemaManager, cfg, | |
logger, authorizer, db, modulesProvider, metrics) | |
} | |
t.Run("ensure creation timestamp persists", func(t *testing.T) { | |
reset() | |
beforeUpdate := time.Now().UnixNano() / int64(time.Millisecond) | |
id := strfmt.UUID("34e9df15-0c3b-468d-ab99-f929662834c7") | |
vec := []float32{0, 1, 2} | |
result := &search.Result{ | |
ID: id, | |
ClassName: "ActionClass", | |
Schema: map[string]interface{}{"foo": "bar"}, | |
Created: beforeUpdate, | |
Updated: beforeUpdate, | |
} | |
db.On("ObjectByID", id, mock.Anything, mock.Anything).Return(result, nil).Once() | |
modulesProvider.On("UpdateVector", mock.Anything, mock.AnythingOfType(FindObjectFn)). | |
Return(vec, nil) | |
db.On("PutObject", mock.Anything, mock.Anything).Return(nil).Once() | |
payload := &models.Object{ | |
Class: "ActionClass", | |
ID: id, | |
Properties: map[string]interface{}{"foo": "baz"}, | |
} | |
res, err := manager.UpdateObject(context.Background(), &models.Principal{}, "", id, payload, nil) | |
require.Nil(t, err) | |
expected := &models.Object{ | |
Class: "ActionClass", | |
ID: id, | |
Properties: map[string]interface{}{"foo": "baz"}, | |
CreationTimeUnix: beforeUpdate, | |
} | |
afterUpdate := time.Now().UnixNano() / int64(time.Millisecond) | |
assert.Equal(t, expected.Class, res.Class) | |
assert.Equal(t, expected.ID, res.ID) | |
assert.Equal(t, expected.Properties, res.Properties) | |
assert.Equal(t, expected.CreationTimeUnix, res.CreationTimeUnix) | |
assert.GreaterOrEqual(t, res.LastUpdateTimeUnix, beforeUpdate) | |
assert.LessOrEqual(t, res.LastUpdateTimeUnix, afterUpdate) | |
}) | |
} | |
func Test_UpdateObject(t *testing.T) { | |
var ( | |
cls = "MyClass" | |
id = strfmt.UUID("34e9df15-0c3b-468d-ab99-f929662834c7") | |
beforeUpdate = (time.Now().UnixNano() - 2*int64(time.Millisecond)) / int64(time.Millisecond) | |
vec = []float32{0, 1, 2} | |
anyErr = errors.New("any error") | |
) | |
schema := schema.Schema{ | |
Objects: &models.Schema{ | |
Classes: []*models.Class{ | |
{ | |
Class: cls, | |
VectorIndexConfig: enthnsw.NewDefaultUserConfig(), | |
Properties: []*models.Property{ | |
{ | |
DataType: schema.DataTypeText.PropString(), | |
Tokenization: models.PropertyTokenizationWhitespace, | |
Name: "foo", | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
m := newFakeGetManager(schema) | |
payload := &models.Object{ | |
Class: cls, | |
ID: id, | |
Properties: map[string]interface{}{"foo": "baz"}, | |
} | |
// the object might not exist | |
m.repo.On("Object", cls, id, mock.Anything, mock.Anything, "").Return(nil, anyErr).Once() | |
_, err := m.UpdateObject(context.Background(), &models.Principal{}, cls, id, payload, nil) | |
if err == nil { | |
t.Fatalf("must return an error if object() fails") | |
} | |
result := &search.Result{ | |
ID: id, | |
ClassName: cls, | |
Schema: map[string]interface{}{"foo": "bar"}, | |
Created: beforeUpdate, | |
Updated: beforeUpdate, | |
} | |
m.repo.On("Object", cls, id, mock.Anything, mock.Anything, "").Return(result, nil).Once() | |
m.modulesProvider.On("UpdateVector", mock.Anything, mock.AnythingOfType(FindObjectFn)). | |
Return(vec, nil) | |
m.repo.On("PutObject", mock.Anything, mock.Anything).Return(nil).Once() | |
expected := &models.Object{ | |
Class: cls, | |
ID: id, | |
Properties: map[string]interface{}{"foo": "baz"}, | |
CreationTimeUnix: beforeUpdate, | |
Vector: vec, | |
} | |
res, err := m.UpdateObject(context.Background(), &models.Principal{}, cls, id, payload, nil) | |
require.Nil(t, err) | |
if res.LastUpdateTimeUnix <= beforeUpdate { | |
t.Error("time after update must be greater than time before update ") | |
} | |
res.LastUpdateTimeUnix = 0 // to allow for equality | |
assert.Equal(t, expected, res) | |
} | |