Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: [email protected] | |
| // | |
| package objects | |
| import ( | |
| "context" | |
| "errors" | |
| "fmt" | |
| "reflect" | |
| "testing" | |
| "github.com/go-openapi/strfmt" | |
| "github.com/sirupsen/logrus/hooks/test" | |
| "github.com/stretchr/testify/assert" | |
| "github.com/stretchr/testify/require" | |
| "github.com/weaviate/weaviate/entities/additional" | |
| "github.com/weaviate/weaviate/entities/models" | |
| "github.com/weaviate/weaviate/usecases/config" | |
| ) | |
| // A component-test like test suite that makes sure that every available UC is | |
| // potentially protected with the Authorization plugin | |
| func Test_Kinds_Authorization(t *testing.T) { | |
| type testCase struct { | |
| methodName string | |
| additionalArgs []interface{} | |
| expectedVerb string | |
| expectedResource string | |
| } | |
| tests := []testCase{ | |
| // single kind | |
| { | |
| methodName: "AddObject", | |
| additionalArgs: []interface{}{(*models.Object)(nil)}, | |
| expectedVerb: "create", | |
| expectedResource: "objects", | |
| }, | |
| { | |
| methodName: "ValidateObject", | |
| additionalArgs: []interface{}{(*models.Object)(nil)}, | |
| expectedVerb: "validate", | |
| expectedResource: "objects", | |
| }, | |
| { | |
| methodName: "GetObject", | |
| additionalArgs: []interface{}{"", strfmt.UUID("foo"), additional.Properties{}}, | |
| expectedVerb: "get", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "DeleteObject", | |
| additionalArgs: []interface{}{"class", strfmt.UUID("foo")}, | |
| expectedVerb: "delete", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| { // deprecated by the one above | |
| methodName: "DeleteObject", | |
| additionalArgs: []interface{}{"", strfmt.UUID("foo")}, | |
| expectedVerb: "delete", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "UpdateObject", | |
| additionalArgs: []interface{}{"class", strfmt.UUID("foo"), (*models.Object)(nil)}, | |
| expectedVerb: "update", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| { // deprecated by the one above | |
| methodName: "UpdateObject", | |
| additionalArgs: []interface{}{"", strfmt.UUID("foo"), (*models.Object)(nil)}, | |
| expectedVerb: "update", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "MergeObject", | |
| additionalArgs: []interface{}{ | |
| &models.Object{Class: "class", ID: "foo"}, | |
| (*additional.ReplicationProperties)(nil), | |
| }, | |
| expectedVerb: "update", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| { | |
| methodName: "GetObjectsClass", | |
| additionalArgs: []interface{}{strfmt.UUID("foo")}, | |
| expectedVerb: "get", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "GetObjectClassFromName", | |
| additionalArgs: []interface{}{strfmt.UUID("foo")}, | |
| expectedVerb: "get", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "HeadObject", | |
| additionalArgs: []interface{}{"class", strfmt.UUID("foo")}, | |
| expectedVerb: "head", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| { // deprecated by the one above | |
| methodName: "HeadObject", | |
| additionalArgs: []interface{}{"", strfmt.UUID("foo")}, | |
| expectedVerb: "head", | |
| expectedResource: "objects/foo", | |
| }, | |
| // query objects | |
| { | |
| methodName: "Query", | |
| additionalArgs: []interface{}{new(QueryParams)}, | |
| expectedVerb: "list", | |
| expectedResource: "objects", | |
| }, | |
| { // list objects is deprecated by query | |
| methodName: "GetObjects", | |
| additionalArgs: []interface{}{(*int64)(nil), (*int64)(nil), (*string)(nil), (*string)(nil), additional.Properties{}}, | |
| expectedVerb: "list", | |
| expectedResource: "objects", | |
| }, | |
| // reference on objects | |
| { | |
| methodName: "AddObjectReference", | |
| additionalArgs: []interface{}{AddReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}, (*models.SingleRef)(nil)}, | |
| expectedVerb: "update", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| { | |
| methodName: "DeleteObjectReference", | |
| additionalArgs: []interface{}{strfmt.UUID("foo"), "some prop", (*models.SingleRef)(nil)}, | |
| expectedVerb: "update", | |
| expectedResource: "objects/foo", | |
| }, | |
| { | |
| methodName: "UpdateObjectReferences", | |
| additionalArgs: []interface{}{&PutReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}}, | |
| expectedVerb: "update", | |
| expectedResource: "objects/class/foo", | |
| }, | |
| } | |
| t.Run("verify that a test for every public method exists", func(t *testing.T) { | |
| testedMethods := make([]string, len(tests)) | |
| for i, test := range tests { | |
| testedMethods[i] = test.methodName | |
| } | |
| for _, method := range allExportedMethods(&Manager{}) { | |
| assert.Contains(t, testedMethods, method) | |
| } | |
| }) | |
| t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { | |
| principal := &models.Principal{} | |
| logger, _ := test.NewNullLogger() | |
| for _, test := range tests { | |
| if test.methodName != "MergeObject" { | |
| continue | |
| } | |
| t.Run(test.methodName, func(t *testing.T) { | |
| schemaManager := &fakeSchemaManager{} | |
| locks := &fakeLocks{} | |
| cfg := &config.WeaviateConfig{} | |
| authorizer := &authDenier{} | |
| vectorRepo := &fakeVectorRepo{} | |
| manager := NewManager(locks, schemaManager, | |
| cfg, logger, authorizer, | |
| vectorRepo, getFakeModulesProvider(), nil) | |
| args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) | |
| out, _ := callFuncByName(manager, test.methodName, args...) | |
| require.Len(t, authorizer.calls, 1, "authorizer must be called") | |
| aerr := out[len(out)-1].Interface().(error) | |
| if err, ok := aerr.(*Error); !ok || !err.Forbidden() { | |
| assert.Equal(t, errors.New("just a test fake"), aerr, | |
| "execution must abort with authorizer error") | |
| } | |
| assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, | |
| authorizer.calls[0], "correct parameters must have been used on authorizer") | |
| }) | |
| } | |
| }) | |
| } | |
| func Test_BatchKinds_Authorization(t *testing.T) { | |
| type testCase struct { | |
| methodName string | |
| additionalArgs []interface{} | |
| expectedVerb string | |
| expectedResource string | |
| } | |
| tests := []testCase{ | |
| { | |
| methodName: "AddObjects", | |
| additionalArgs: []interface{}{ | |
| []*models.Object{}, | |
| []*string{}, | |
| &additional.ReplicationProperties{}, | |
| }, | |
| expectedVerb: "create", | |
| expectedResource: "batch/objects", | |
| }, | |
| { | |
| methodName: "AddReferences", | |
| additionalArgs: []interface{}{ | |
| []*models.BatchReference{}, | |
| &additional.ReplicationProperties{}, | |
| }, | |
| expectedVerb: "update", | |
| expectedResource: "batch/*", | |
| }, | |
| { | |
| methodName: "DeleteObjects", | |
| additionalArgs: []interface{}{ | |
| &models.BatchDeleteMatch{}, | |
| (*bool)(nil), | |
| (*string)(nil), | |
| &additional.ReplicationProperties{}, | |
| "", | |
| }, | |
| expectedVerb: "delete", | |
| expectedResource: "batch/objects", | |
| }, | |
| { | |
| methodName: "DeleteObjectsFromGRPC", | |
| additionalArgs: []interface{}{ | |
| BatchDeleteParams{}, | |
| &additional.ReplicationProperties{}, | |
| "", | |
| }, | |
| expectedVerb: "delete", | |
| expectedResource: "batch/objects", | |
| }, | |
| } | |
| t.Run("verify that a test for every public method exists", func(t *testing.T) { | |
| testedMethods := make([]string, len(tests)) | |
| for i, test := range tests { | |
| testedMethods[i] = test.methodName | |
| } | |
| for _, method := range allExportedMethods(&BatchManager{}) { | |
| assert.Contains(t, testedMethods, method) | |
| } | |
| }) | |
| t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { | |
| principal := &models.Principal{} | |
| logger, _ := test.NewNullLogger() | |
| for _, test := range tests { | |
| schemaManager := &fakeSchemaManager{} | |
| locks := &fakeLocks{} | |
| cfg := &config.WeaviateConfig{} | |
| authorizer := &authDenier{} | |
| vectorRepo := &fakeVectorRepo{} | |
| modulesProvider := getFakeModulesProvider() | |
| manager := NewBatchManager(vectorRepo, modulesProvider, locks, schemaManager, cfg, logger, authorizer, nil) | |
| args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) | |
| out, _ := callFuncByName(manager, test.methodName, args...) | |
| require.Len(t, authorizer.calls, 1, "authorizer must be called") | |
| assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(), | |
| "execution must abort with authorizer error") | |
| assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, | |
| authorizer.calls[0], "correct parameters must have been used on authorizer") | |
| } | |
| }) | |
| } | |
| type authorizeCall struct { | |
| principal *models.Principal | |
| verb string | |
| resource string | |
| } | |
| type authDenier struct { | |
| calls []authorizeCall | |
| } | |
| func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error { | |
| a.calls = append(a.calls, authorizeCall{principal, verb, resource}) | |
| return errors.New("just a test fake") | |
| } | |
| // inspired by https://stackoverflow.com/a/33008200 | |
| func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) { | |
| managerValue := reflect.ValueOf(manager) | |
| m := managerValue.MethodByName(funcName) | |
| if !m.IsValid() { | |
| return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName) | |
| } | |
| in := make([]reflect.Value, len(params)) | |
| for i, param := range params { | |
| in[i] = reflect.ValueOf(param) | |
| } | |
| out = m.Call(in) | |
| return | |
| } | |
| func allExportedMethods(subject interface{}) []string { | |
| var methods []string | |
| subjectType := reflect.TypeOf(subject) | |
| for i := 0; i < subjectType.NumMethod(); i++ { | |
| name := subjectType.Method(i).Name | |
| if name[0] >= 'A' && name[0] <= 'Z' { | |
| methods = append(methods, name) | |
| } | |
| } | |
| return methods | |
| } | |