SemanticSearchPOC / usecases /schema /authorization_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
7.27 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package schema
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/usecases/config"
)
// Test_Schema_Authorization is a component-test-like suite that makes sure
// every available use case on the Manager is protected by the Authorization
// plugin.
//
// It has two parts:
//   - a completeness check: every exported method of *Manager must either
//     appear in the test table below or be explicitly listed as an internal,
//     non-user-facing method;
//   - a behavior check: every method in the table is invoked through
//     reflection on a Manager wired to an Authorizer that always denies, and
//     the test asserts that the Authorizer was called exactly once with the
//     expected verb and resource and that its error is what the method returns.
func Test_Schema_Authorization(t *testing.T) {
	type testCase struct {
		methodName       string
		additionalArgs   []interface{} // arguments following (ctx, principal)
		expectedVerb     string
		expectedResource string
	}

	tests := []testCase{
		{
			methodName:       "GetSchema",
			expectedVerb:     "list",
			expectedResource: "schema/*",
		},
		{
			methodName:       "GetClass",
			additionalArgs:   []interface{}{"classname"},
			expectedVerb:     "list",
			expectedResource: "schema/*",
		},
		{
			methodName:       "GetShardsStatus",
			additionalArgs:   []interface{}{"className", "tenant"},
			expectedVerb:     "list",
			expectedResource: "schema/className/shards",
		},
		{
			methodName:       "AddClass",
			additionalArgs:   []interface{}{&models.Class{}},
			expectedVerb:     "create",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "UpdateClass",
			additionalArgs:   []interface{}{"somename", &models.Class{}},
			expectedVerb:     "update",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "DeleteClass",
			additionalArgs:   []interface{}{"somename"},
			expectedVerb:     "delete",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "AddClassProperty",
			additionalArgs:   []interface{}{"somename", &models.Property{}},
			expectedVerb:     "update",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "MergeClassObjectProperty",
			additionalArgs:   []interface{}{"somename", &models.Property{}},
			expectedVerb:     "update",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "DeleteClassProperty",
			additionalArgs:   []interface{}{"somename", "someprop"},
			expectedVerb:     "update",
			expectedResource: "schema/objects",
		},
		{
			methodName:       "UpdateShardStatus",
			additionalArgs:   []interface{}{"className", "shardName", "targetStatus"},
			expectedVerb:     "update",
			expectedResource: "schema/className/shards/shardName",
		},
		{
			methodName:       "AddTenants",
			additionalArgs:   []interface{}{"className", []*models.Tenant{{Name: "P1"}}},
			expectedVerb:     "update",
			expectedResource: tenantsPath,
		},
		{
			methodName: "UpdateTenants",
			additionalArgs: []interface{}{"className", []*models.Tenant{
				{Name: "P1", ActivityStatus: models.TenantActivityStatusHOT},
			}},
			expectedVerb:     "update",
			expectedResource: tenantsPath,
		},
		{
			methodName:       "DeleteTenants",
			additionalArgs:   []interface{}{"className", []string{"P1"}},
			expectedVerb:     "delete",
			expectedResource: tenantsPath,
		},
		{
			methodName:       "GetTenants",
			additionalArgs:   []interface{}{"className"},
			expectedVerb:     "get",
			expectedResource: tenantsPath,
		},
	}

	t.Run("verify that a test for every public method exists", func(t *testing.T) {
		testedMethods := make([]string, len(tests))
		for i, tc := range tests {
			testedMethods[i] = tc.methodName
		}

		for _, method := range allExportedMethods(&Manager{}) {
			switch method {
			case "RegisterSchemaUpdateCallback",
				"UpdateMeta", "GetSchemaSkipAuth", "IndexedInverted", "RLock", "RUnlock", "Lock", "Unlock",
				"TryLock", "RLocker", "TryRLock", // introduced by sync.Mutex in go 1.18
				"Nodes", "NodeName", "ClusterHealthScore", "ClusterStatus", "ResolveParentNodes",
				"CopyShardingState", "TxManager", "RestoreClass",
				"ShardOwner", "TenantShard", "ShardFromUUID", "LockGuard", "RLockGuard", "ShardReplicas",
				"StartServing", "Shutdown": // internal methods to indicate readiness state
				// don't require auth on methods which are exported because other
				// packages need to call them for maintenance and other regular jobs,
				// but aren't user facing
				continue
			}
			assert.Contains(t, testedMethods, method)
		}
	})

	t.Run("verify the tested methods require correct permissions from the Authorizer", func(t *testing.T) {
		principal := &models.Principal{}
		logger, _ := test.NewNullLogger()

		// The loop variable is deliberately not called "test" so that it does
		// not shadow the imported logrus test package used above.
		for _, tc := range tests {
			t.Run(tc.methodName, func(t *testing.T) {
				authorizer := &authDenier{}
				manager, err := NewManager(&NilMigrator{}, newFakeRepo(),
					logger, authorizer, config.Config{},
					dummyParseVectorConfig, &fakeVectorizerValidator{},
					dummyValidateInvertedConfig, &fakeModuleConfig{},
					&fakeClusterState{hosts: []string{"node1"}}, &fakeTxClient{},
					&fakeTxPersistence{}, &fakeScaleOutManager{})
				require.NoError(t, err)

				var args []interface{}
				if tc.methodName == "GetSchema" {
					// no context on this method
					args = append([]interface{}{principal}, tc.additionalArgs...)
				} else {
					args = append([]interface{}{context.Background(), principal}, tc.additionalArgs...)
				}

				// Fail with a precise message if the table references a method
				// that does not exist, rather than with a misleading
				// "Authorizer must be called" further down.
				out, err := callFuncByName(manager, tc.methodName, args...)
				require.NoError(t, err, "tested method must exist on the Manager")

				require.Len(t, authorizer.calls, 1, "Authorizer must be called")
				// Guard the index expression below against methods without
				// return values.
				require.NotEmpty(t, out, "tested method must return an error")
				assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
					"execution must abort with Authorizer error")
				assert.Equal(t, authorizeCall{principal, tc.expectedVerb, tc.expectedResource},
					authorizer.calls[0], "correct parameters must have been used on Authorizer")
			})
		}
	})
}
// authorizeCall records the arguments of one Authorize invocation so that a
// test can compare them against its expectations.
//
// Keep the field order (principal, verb, resource) stable: values of this type
// are built with unkeyed composite literals elsewhere in this file.
type authorizeCall struct {
	principal      *models.Principal
	verb, resource string
}
// authDenier is a test double for the authorizer passed to NewManager: it
// rejects every request and remembers what it was asked.
type authDenier struct {
	// calls holds one entry per Authorize invocation, in call order.
	calls []authorizeCall
}

// Authorize appends the received arguments to a.calls and then always fails
// with the fixed error "just a test fake".
func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
	recorded := authorizeCall{
		principal: principal,
		verb:      verb,
		resource:  resource,
	}
	a.calls = append(a.calls, recorded)

	return errors.New("just a test fake")
}
// callFuncByName looks up the method funcName on manager via reflection, calls
// it with params as positional arguments and returns the method's results as
// reflect.Values.
//
// An error together with an empty result slice is returned when manager is nil
// or when it has no method named funcName. An untyped nil in params is replaced
// by the zero value of the parameter type at that position, because reflect
// cannot derive a type from a nil interface. Any other mismatch between params
// and the method signature (argument count or types) still panics inside
// reflect's Call.
//
// inspired by https://stackoverflow.com/a/33008200
func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
	managerValue := reflect.ValueOf(manager)
	if !managerValue.IsValid() {
		// reflect.ValueOf(nil) yields the zero Value, on which MethodByName
		// would panic.
		return make([]reflect.Value, 0), fmt.Errorf("cannot call method %q on a nil value", funcName)
	}

	m := managerValue.MethodByName(funcName)
	if !m.IsValid() {
		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
	}

	signature := m.Type()
	in := make([]reflect.Value, len(params))
	for i, param := range params {
		if param != nil {
			in[i] = reflect.ValueOf(param)
			continue
		}

		// reflect.ValueOf(nil) is invalid and would make Call panic, so use
		// the zero value of whatever type the method expects at this position.
		switch {
		case signature.IsVariadic() && i >= signature.NumIn()-1:
			// Call packs trailing arguments into the variadic slice itself,
			// so each of them must have the slice's element type.
			in[i] = reflect.Zero(signature.In(signature.NumIn() - 1).Elem())
		case i < signature.NumIn():
			in[i] = reflect.Zero(signature.In(i))
		}
		// Otherwise there are more arguments than parameters; leave the slot
		// untouched and let Call report the mismatch.
	}

	return m.Call(in), nil
}
// allExportedMethods returns the names of all exported methods in the method
// set of subject's dynamic type, in the order reflect reports them.
//
// Pass a pointer to include methods declared on pointer receivers. A nil
// subject has no dynamic type and yields nil.
func allExportedMethods(subject interface{}) []string {
	subjectType := reflect.TypeOf(subject)
	if subjectType == nil {
		// reflect.TypeOf(nil) returns a nil Type; calling NumMethod on it
		// would panic.
		return nil
	}

	var methods []string
	for i := 0; i < subjectType.NumMethod(); i++ {
		// IsExported follows the language definition of "exported", so names
		// starting with a non-ASCII uppercase letter are kept as well.
		if method := subjectType.Method(i); method.IsExported() {
			methods = append(methods, method.Name)
		}
	}

	return methods
}