SemanticSearchPOC / usecases /objects /authorization_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
10.2 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package objects
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/usecases/config"
)
// Test_Kinds_Authorization is a component-test-like suite that makes sure that
// every available use case (UC) is potentially protected with the
// Authorization plugin.
//
// It consists of two subtests driven by the same table:
//   - the first asserts that every exported method of *Manager has at least one
//     table entry;
//   - the second builds a Manager around an authDenier, calls the method by
//     name and asserts that the authorizer was consulted exactly once with the
//     expected verb and resource, and that the call aborted with an error.
func Test_Kinds_Authorization(t *testing.T) {
	type testCase struct {
		methodName       string
		additionalArgs   []interface{}
		expectedVerb     string
		expectedResource string
	}

	tests := []testCase{
		// single kind
		{
			methodName:       "AddObject",
			additionalArgs:   []interface{}{(*models.Object)(nil)},
			expectedVerb:     "create",
			expectedResource: "objects",
		},
		{
			methodName:       "ValidateObject",
			additionalArgs:   []interface{}{(*models.Object)(nil)},
			expectedVerb:     "validate",
			expectedResource: "objects",
		},
		{
			methodName:       "GetObject",
			additionalArgs:   []interface{}{"", strfmt.UUID("foo"), additional.Properties{}},
			expectedVerb:     "get",
			expectedResource: "objects/foo",
		},
		{
			methodName:       "DeleteObject",
			additionalArgs:   []interface{}{"class", strfmt.UUID("foo")},
			expectedVerb:     "delete",
			expectedResource: "objects/class/foo",
		},
		{ // deprecated by the one above
			methodName:       "DeleteObject",
			additionalArgs:   []interface{}{"", strfmt.UUID("foo")},
			expectedVerb:     "delete",
			expectedResource: "objects/foo",
		},
		{
			methodName:       "UpdateObject",
			additionalArgs:   []interface{}{"class", strfmt.UUID("foo"), (*models.Object)(nil)},
			expectedVerb:     "update",
			expectedResource: "objects/class/foo",
		},
		{ // deprecated by the one above
			methodName:       "UpdateObject",
			additionalArgs:   []interface{}{"", strfmt.UUID("foo"), (*models.Object)(nil)},
			expectedVerb:     "update",
			expectedResource: "objects/foo",
		},
		{
			methodName: "MergeObject",
			additionalArgs: []interface{}{
				&models.Object{Class: "class", ID: "foo"},
				(*additional.ReplicationProperties)(nil),
			},
			expectedVerb:     "update",
			expectedResource: "objects/class/foo",
		},
		{
			methodName:       "GetObjectsClass",
			additionalArgs:   []interface{}{strfmt.UUID("foo")},
			expectedVerb:     "get",
			expectedResource: "objects/foo",
		},
		{
			methodName:       "GetObjectClassFromName",
			additionalArgs:   []interface{}{strfmt.UUID("foo")},
			expectedVerb:     "get",
			expectedResource: "objects/foo",
		},
		{
			methodName:       "HeadObject",
			additionalArgs:   []interface{}{"class", strfmt.UUID("foo")},
			expectedVerb:     "head",
			expectedResource: "objects/class/foo",
		},
		{ // deprecated by the one above
			methodName:       "HeadObject",
			additionalArgs:   []interface{}{"", strfmt.UUID("foo")},
			expectedVerb:     "head",
			expectedResource: "objects/foo",
		},

		// query objects
		{
			methodName:       "Query",
			additionalArgs:   []interface{}{new(QueryParams)},
			expectedVerb:     "list",
			expectedResource: "objects",
		},
		{ // list objects is deprecated by query
			methodName:       "GetObjects",
			additionalArgs:   []interface{}{(*int64)(nil), (*int64)(nil), (*string)(nil), (*string)(nil), additional.Properties{}},
			expectedVerb:     "list",
			expectedResource: "objects",
		},

		// reference on objects
		{
			methodName:       "AddObjectReference",
			additionalArgs:   []interface{}{AddReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}, (*models.SingleRef)(nil)},
			expectedVerb:     "update",
			expectedResource: "objects/class/foo",
		},
		{
			methodName:       "DeleteObjectReference",
			additionalArgs:   []interface{}{strfmt.UUID("foo"), "some prop", (*models.SingleRef)(nil)},
			expectedVerb:     "update",
			expectedResource: "objects/foo",
		},
		{
			methodName:       "UpdateObjectReferences",
			additionalArgs:   []interface{}{&PutReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}},
			expectedVerb:     "update",
			expectedResource: "objects/class/foo",
		},
	}

	t.Run("verify that a test for every public method exists", func(t *testing.T) {
		testedMethods := make([]string, len(tests))
		for i, tc := range tests {
			testedMethods[i] = tc.methodName
		}

		for _, method := range allExportedMethods(&Manager{}) {
			assert.Contains(t, testedMethods, method)
		}
	})

	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
		principal := &models.Principal{}
		logger, _ := test.NewNullLogger() // the log hook is not inspected by this test

		for _, tc := range tests {
			// NOTE(review): only the MergeObject case is executed here; every
			// other table entry is skipped. This looks like a leftover filter —
			// confirm whether the remaining cases still match the current method
			// signatures before removing it.
			if tc.methodName != "MergeObject" {
				continue
			}

			t.Run(tc.methodName, func(t *testing.T) {
				// Fresh fakes per case so that authorizer.calls only contains
				// the calls made by this method invocation.
				schemaManager := &fakeSchemaManager{}
				locks := &fakeLocks{}
				cfg := &config.WeaviateConfig{}
				authorizer := &authDenier{}
				vectorRepo := &fakeVectorRepo{}
				manager := NewManager(locks, schemaManager,
					cfg, logger, authorizer,
					vectorRepo, getFakeModulesProvider(), nil)

				args := append([]interface{}{context.Background(), principal}, tc.additionalArgs...)
				out, err := callFuncByName(manager, tc.methodName, args...)
				require.NoError(t, err, "method must be callable by name")
				require.Len(t, authorizer.calls, 1, "authorizer must be called")
				require.NotEmpty(t, out, "method must return at least an error value")

				// The last return value is expected to carry the error; use the
				// two-value assertion so a nil error fails the test instead of
				// panicking.
				aerr, ok := out[len(out)-1].Interface().(error)
				require.True(t, ok, "last return value must be a non-nil error")

				// Either the method wrapped the denial into a Forbidden *Error,
				// or it must have returned the authorizer's error unchanged.
				if objErr, ok := aerr.(*Error); !ok || !objErr.Forbidden() {
					assert.Equal(t, errors.New("just a test fake"), aerr,
						"execution must abort with authorizer error")
				}

				assert.Equal(t, authorizeCall{principal, tc.expectedVerb, tc.expectedResource},
					authorizer.calls[0], "correct parameters must have been used on authorizer")
			})
		}
	})
}
// Test_BatchKinds_Authorization verifies that the exported methods of
// BatchManager consult the authorizer before doing any work.
//
// The first subtest asserts that every exported method of *BatchManager has a
// table entry. The second builds a BatchManager around an authDenier, calls
// each method by name and asserts that the authorizer was called exactly once
// with the expected verb and resource, and that the authorizer's error was
// returned as the last return value.
func Test_BatchKinds_Authorization(t *testing.T) {
	type testCase struct {
		methodName       string
		additionalArgs   []interface{}
		expectedVerb     string
		expectedResource string
	}

	tests := []testCase{
		{
			methodName: "AddObjects",
			additionalArgs: []interface{}{
				[]*models.Object{},
				[]*string{},
				&additional.ReplicationProperties{},
			},
			expectedVerb:     "create",
			expectedResource: "batch/objects",
		},
		{
			methodName: "AddReferences",
			additionalArgs: []interface{}{
				[]*models.BatchReference{},
				&additional.ReplicationProperties{},
			},
			expectedVerb:     "update",
			expectedResource: "batch/*",
		},
		{
			methodName: "DeleteObjects",
			additionalArgs: []interface{}{
				&models.BatchDeleteMatch{},
				(*bool)(nil),
				(*string)(nil),
				&additional.ReplicationProperties{},
				"",
			},
			expectedVerb:     "delete",
			expectedResource: "batch/objects",
		},
		{
			methodName: "DeleteObjectsFromGRPC",
			additionalArgs: []interface{}{
				BatchDeleteParams{},
				&additional.ReplicationProperties{},
				"",
			},
			expectedVerb:     "delete",
			expectedResource: "batch/objects",
		},
	}

	t.Run("verify that a test for every public method exists", func(t *testing.T) {
		testedMethods := make([]string, len(tests))
		for i, tc := range tests {
			testedMethods[i] = tc.methodName
		}

		for _, method := range allExportedMethods(&BatchManager{}) {
			assert.Contains(t, testedMethods, method)
		}
	})

	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
		principal := &models.Principal{}
		logger, _ := test.NewNullLogger() // the log hook is not inspected by this test

		for _, tc := range tests {
			// One subtest per method: a failing require only aborts that
			// method's case, and the failure is reported under its name.
			t.Run(tc.methodName, func(t *testing.T) {
				// Fresh fakes per case so that authorizer.calls only contains
				// the calls made by this method invocation.
				schemaManager := &fakeSchemaManager{}
				locks := &fakeLocks{}
				cfg := &config.WeaviateConfig{}
				authorizer := &authDenier{}
				vectorRepo := &fakeVectorRepo{}
				modulesProvider := getFakeModulesProvider()
				manager := NewBatchManager(vectorRepo, modulesProvider, locks, schemaManager, cfg, logger, authorizer, nil)

				args := append([]interface{}{context.Background(), principal}, tc.additionalArgs...)
				out, err := callFuncByName(manager, tc.methodName, args...)
				require.NoError(t, err, "method must be callable by name")
				require.Len(t, authorizer.calls, 1, "authorizer must be called")
				require.NotEmpty(t, out, "method must return at least an error value")

				assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
					"execution must abort with authorizer error")
				assert.Equal(t, authorizeCall{principal, tc.expectedVerb, tc.expectedResource},
					authorizer.calls[0], "correct parameters must have been used on authorizer")
			})
		}
	})
}
// authorizeCall captures the arguments of a single Authorize invocation so
// that tests can compare them against the expected values. The field order is
// relied upon by positional composite literals in this file.
type authorizeCall struct {
	principal      *models.Principal
	verb, resource string
}
// authDenier is a test double for the authorizer: it remembers every request
// it receives and rejects all of them.
type authDenier struct {
	// calls holds one entry per Authorize invocation, in call order.
	calls []authorizeCall
}

// Authorize appends the received arguments to a.calls and always returns the
// error "just a test fake".
func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
	record := authorizeCall{
		principal: principal,
		verb:      verb,
		resource:  resource,
	}
	a.calls = append(a.calls, record)

	denied := errors.New("just a test fake")
	return denied
}
// callFuncByName looks up the method funcName on manager via reflection,
// calls it with params as its arguments and returns the method's results.
//
// Instead of letting the reflect package panic, an error (together with an
// empty result slice) is returned when:
//   - manager is an untyped nil,
//   - no method called funcName exists on manager,
//   - the number of params does not match the method's parameter list, or
//   - a param is not assignable to the corresponding parameter type.
//
// An untyped nil param is passed as the zero value of the corresponding
// parameter type. Panics raised by the called method itself are not recovered.
//
// inspired by https://stackoverflow.com/a/33008200
func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
	managerValue := reflect.ValueOf(manager)
	if !managerValue.IsValid() {
		// reflect.ValueOf(nil) yields the zero Value, on which MethodByName panics.
		return make([]reflect.Value, 0), fmt.Errorf("cannot look up method %q on a nil manager", funcName)
	}

	m := managerValue.MethodByName(funcName)
	if !m.IsValid() {
		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
	}

	mt := m.Type()
	numIn := mt.NumIn()
	variadic := mt.IsVariadic()
	switch {
	case variadic && len(params) < numIn-1:
		return make([]reflect.Value, 0), fmt.Errorf("method %q expects at least %d arguments, got %d",
			funcName, numIn-1, len(params))
	case !variadic && len(params) != numIn:
		return make([]reflect.Value, 0), fmt.Errorf("method %q expects %d arguments, got %d",
			funcName, numIn, len(params))
	}

	in := make([]reflect.Value, len(params))
	for i, param := range params {
		// For a variadic method every trailing argument maps onto the element
		// type of the final slice parameter.
		var want reflect.Type
		if variadic && i >= numIn-1 {
			want = mt.In(numIn - 1).Elem()
		} else {
			want = mt.In(i)
		}

		if param == nil {
			// reflect.ValueOf(nil) is the zero Value, which Call rejects.
			in[i] = reflect.Zero(want)
			continue
		}

		v := reflect.ValueOf(param)
		if !v.Type().AssignableTo(want) {
			return make([]reflect.Value, 0), fmt.Errorf("method %q argument %d: %s is not assignable to %s",
				funcName, i, v.Type(), want)
		}
		in[i] = v
	}

	return m.Call(in), nil
}
// allExportedMethods returns the names of the methods in subject's method set
// whose first byte is an ASCII upper-case letter, in the order in which
// reflection reports them. The result is nil when no method qualifies.
func allExportedMethods(subject interface{}) []string {
	var names []string

	typ := reflect.TypeOf(subject)
	count := typ.NumMethod()
	for idx := 0; idx < count; idx++ {
		methodName := typ.Method(idx).Name
		// Skip anything that does not start with 'A'..'Z'.
		if first := methodName[0]; first < 'A' || first > 'Z' {
			continue
		}
		names = append(names, methodName)
	}

	return names
}