KevinStephenson
Adding in weaviate code
b110593
raw
history blame
21.6 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package modules
import (
"context"
"fmt"
"io"
"net/http"
"testing"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/tailor-inc/graphql"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
enitiesSchema "github.com/weaviate/weaviate/entities/schema"
ubackup "github.com/weaviate/weaviate/usecases/backup"
)
func TestModulesProvider(t *testing.T) {
t.Run("should register simple module", func(t *testing.T) {
// given
modulesProvider := NewProvider()
class := &models.Class{
Class: "ClassOne",
Vectorizer: "mod1",
}
schema := &models.Schema{
Classes: []*models.Class{class},
}
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
params := map[string]interface{}{}
params["nearArgumentSomeParam"] = string("doesn't matter here")
arguments := map[string]interface{}{}
arguments["nearArgument"] = params
// when
modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument"))
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
registered := modulesProvider.GetAll()
getArgs := modulesProvider.GetArguments(class)
exploreArgs := modulesProvider.ExploreArguments(schema)
extractedArgs := modulesProvider.ExtractSearchParams(arguments, class.Class)
// then
mod1 := registered[0]
assert.Nil(t, err)
assert.Equal(t, "mod1", mod1.Name())
assert.NotNil(t, getArgs["nearArgument"])
assert.NotNil(t, exploreArgs["nearArgument"])
assert.NotNil(t, extractedArgs["nearArgument"])
})
t.Run("should not register modules providing the same search param", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument"))
modulesProvider.Register(newGraphQLModule("mod2").withArg("nearArgument"))
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.Nil(t, err)
})
t.Run("should not register modules providing internal search param", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument"))
modulesProvider.Register(newGraphQLModule("mod3").
withExtractFn("limit").
withExtractFn("where").
withExtractFn("nearVector").
withExtractFn("nearObject").
withExtractFn("group"),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "nearObject conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "nearVector conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "where conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "group conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "limit conflicts with weaviate's internal searcher in modules: [mod3]")
})
t.Run("should not register modules providing faulty params", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument"))
modulesProvider.Register(newGraphQLModule("mod2").withArg("nearArgument"))
modulesProvider.Register(newGraphQLModule("mod3").
withExtractFn("limit").
withExtractFn("where").
withExtractFn("nearVector").
withExtractFn("nearObject").
withExtractFn("group"),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "nearObject conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "nearVector conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "where conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "group conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "limit conflicts with weaviate's internal searcher in modules: [mod3]")
})
t.Run("should register simple additional property module", func(t *testing.T) {
// given
modulesProvider := NewProvider()
class := &models.Class{
Class: "ClassOne",
Vectorizer: "mod1",
}
schema := &models.Schema{
Classes: []*models.Class{class},
}
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
params := map[string]interface{}{}
params["nearArgumentSomeParam"] = string("doesn't matter here")
arguments := map[string]interface{}{}
arguments["nearArgument"] = params
// when
modulesProvider.Register(newGraphQLAdditionalModule("mod1").
withGraphQLArg("featureProjection", []string{"featureProjection"}).
withGraphQLArg("interpretation", []string{"interpretation"}).
withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}).
withRestApiArg("interpretation", []string{"interpretation"}).
withArg("nearArgument"),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
registered := modulesProvider.GetAll()
getArgs := modulesProvider.GetArguments(class)
exploreArgs := modulesProvider.ExploreArguments(schema)
extractedArgs := modulesProvider.ExtractSearchParams(arguments, class.Class)
restApiFPArgs := modulesProvider.RestApiAdditionalProperties("featureProjection", class)
restApiInterpretationArgs := modulesProvider.RestApiAdditionalProperties("interpretation", class)
graphQLArgs := modulesProvider.GraphQLAdditionalFieldNames()
// then
mod1 := registered[0]
assert.Nil(t, err)
assert.Equal(t, "mod1", mod1.Name())
assert.NotNil(t, getArgs["nearArgument"])
assert.NotNil(t, exploreArgs["nearArgument"])
assert.NotNil(t, extractedArgs["nearArgument"])
assert.NotNil(t, restApiFPArgs["featureProjection"])
assert.NotNil(t, restApiInterpretationArgs["interpretation"])
assert.Contains(t, graphQLArgs, "featureProjection")
assert.Contains(t, graphQLArgs, "interpretation")
})
t.Run("should not register additional property modules providing the same params", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLAdditionalModule("mod1").
withArg("nearArgument").
withGraphQLArg("featureProjection", []string{"featureProjection"}).
withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}),
)
modulesProvider.Register(newGraphQLAdditionalModule("mod2").
withArg("nearArgument").
withGraphQLArg("featureProjection", []string{"featureProjection"}).
withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.Nil(t, err)
})
t.Run("should not register additional property modules providing internal search param", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLAdditionalModule("mod1").withArg("nearArgument"))
modulesProvider.Register(newGraphQLAdditionalModule("mod3").
withExtractFn("limit").
withExtractFn("where").
withExtractFn("nearVector").
withExtractFn("nearObject").
withExtractFn("group").
withExtractFn("groupBy").
withExtractFn("hybrid").
withExtractFn("bm25").
withExtractFn("offset").
withExtractFn("after").
withGraphQLArg("group", []string{"group"}).
withGraphQLArg("classification", []string{"classification"}).
withRestApiArg("classification", []string{"classification"}).
withGraphQLArg("certainty", []string{"certainty"}).
withRestApiArg("certainty", []string{"certainty"}).
withGraphQLArg("distance", []string{"distance"}).
withRestApiArg("distance", []string{"distance"}).
withGraphQLArg("id", []string{"id"}).
withRestApiArg("id", []string{"id"}),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "searcher: nearObject conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: nearVector conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: where conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: group conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: groupBy conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: hybrid conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: bm25 conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: limit conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: offset conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: after conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "rest api additional property: classification conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "rest api additional property: certainty conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "rest api additional property: distance conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "rest api additional property: id conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "graphql additional property: classification conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "graphql additional property: certainty conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "graphql additional property: distance conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "graphql additional property: id conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "graphql additional property: group conflicts with weaviate's internal searcher in modules: [mod3]")
})
t.Run("should not register additional property modules providing faulty params", func(t *testing.T) {
// given
modulesProvider := NewProvider()
schemaGetter := getFakeSchemaGetter()
modulesProvider.SetSchemaGetter(schemaGetter)
// when
modulesProvider.Register(newGraphQLAdditionalModule("mod1").
withArg("nearArgument").
withGraphQLArg("semanticPath", []string{"semanticPath"}).
withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}),
)
modulesProvider.Register(newGraphQLAdditionalModule("mod2").
withArg("nearArgument").
withGraphQLArg("semanticPath", []string{"semanticPath"}).
withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}),
)
modulesProvider.Register(newGraphQLModule("mod3").
withExtractFn("limit").
withExtractFn("where").
withExtractFn("nearVector").
withExtractFn("nearObject").
withExtractFn("group"),
)
modulesProvider.Register(newGraphQLAdditionalModule("mod4").
withGraphQLArg("classification", []string{"classification"}).
withRestApiArg("classification", []string{"classification"}).
withGraphQLArg("certainty", []string{"certainty"}).
withRestApiArg("certainty", []string{"certainty"}).
withGraphQLArg("id", []string{"id"}).
withRestApiArg("id", []string{"id"}),
)
logger, _ := test.NewNullLogger()
err := modulesProvider.Init(context.Background(), nil, logger)
// then
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "searcher: nearObject conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: nearVector conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: where conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: group conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "searcher: limit conflicts with weaviate's internal searcher in modules: [mod3]")
assert.Contains(t, err.Error(), "rest api additional property: classification conflicts with weaviate's internal searcher in modules: [mod4]")
assert.Contains(t, err.Error(), "rest api additional property: certainty conflicts with weaviate's internal searcher in modules: [mod4]")
assert.Contains(t, err.Error(), "rest api additional property: id conflicts with weaviate's internal searcher in modules: [mod4]")
assert.Contains(t, err.Error(), "graphql additional property: classification conflicts with weaviate's internal searcher in modules: [mod4]")
assert.Contains(t, err.Error(), "graphql additional property: certainty conflicts with weaviate's internal searcher in modules: [mod4]")
assert.Contains(t, err.Error(), "graphql additional property: id conflicts with weaviate's internal searcher in modules: [mod4]")
})
t.Run("should register module with alt names", func(t *testing.T) {
module := &dummyBackupModuleWithAltNames{}
modulesProvider := NewProvider()
modulesProvider.Register(module)
modByName := modulesProvider.GetByName("SomeBackend")
modByAltName1 := modulesProvider.GetByName("AltBackendName")
modByAltName2 := modulesProvider.GetByName("YetAnotherBackendName")
modMissing := modulesProvider.GetByName("DoesNotExist")
assert.NotNil(t, modByName)
assert.NotNil(t, modByAltName1)
assert.NotNil(t, modByAltName2)
assert.Nil(t, modMissing)
})
t.Run("should provide backup backend", func(t *testing.T) {
module := &dummyBackupModuleWithAltNames{}
modulesProvider := NewProvider()
modulesProvider.Register(module)
provider, ok := interface{}(modulesProvider).(ubackup.BackupBackendProvider)
assert.True(t, ok)
fmt.Printf("provider: %v\n", provider)
backendByName, err1 := provider.BackupBackend("SomeBackend")
backendByAltName, err2 := provider.BackupBackend("YetAnotherBackendName")
assert.NotNil(t, backendByName)
assert.Nil(t, err1)
assert.NotNil(t, backendByAltName)
assert.Nil(t, err2)
})
}
func fakeExtractFn(param map[string]interface{}) interface{} {
extracted := map[string]interface{}{}
extracted["nearArgumentParam"] = []string{"fake"}
return extracted
}
// fakeValidateFn is a stub GraphQL validate function: it accepts every value,
// always reporting success.
func fakeValidateFn(param interface{}) error {
	_ = param // intentionally unused; nothing is ever rejected
	return nil
}
// newGraphQLModule builds a dummyGraphQLModule named name on top of
// newDummyText2VecModule, starting with an empty (non-nil) argument map.
func newGraphQLModule(name string) *dummyGraphQLModule {
	base := newDummyText2VecModule(name)
	m := &dummyGraphQLModule{dummyText2VecModuleNoCapabilities: base}
	m.arguments = make(map[string]modulecapabilities.GraphQLArgument)
	return m
}
// dummyGraphQLModule is a test double that embeds
// dummyText2VecModuleNoCapabilities and adds a configurable set of GraphQL
// arguments.
type dummyGraphQLModule struct {
	dummyText2VecModuleNoCapabilities
	// arguments maps a GraphQL argument name to its handlers; returned by Arguments.
	arguments map[string]modulecapabilities.GraphQLArgument
}
// withArg registers argName with stub handlers and returns m for chaining.
// The get/explore functions each return a fresh empty *graphql.ArgumentConfig,
// extraction uses fakeExtractFn and validation uses fakeValidateFn. Any
// existing entry for argName is replaced.
func (m *dummyGraphQLModule) withArg(argName string) *dummyGraphQLModule {
	emptyConfig := func() *graphql.ArgumentConfig { return &graphql.ArgumentConfig{} }
	m.arguments[argName] = modulecapabilities.GraphQLArgument{
		GetArgumentsFunction:     func(string) *graphql.ArgumentConfig { return emptyConfig() },
		ExploreArgumentsFunction: emptyConfig,
		ExtractFunction:          fakeExtractFn,
		ValidateFunction:         fakeValidateFn,
	}
	return m
}
// withExtractFn makes fakeExtractFn the ExtractFunction for argName and returns
// m for chaining. An existing entry keeps its other handlers; a missing entry
// is created with only ExtractFunction set.
func (m *dummyGraphQLModule) withExtractFn(argName string) *dummyGraphQLModule {
	args := m.arguments
	if existing, found := args[argName]; found {
		existing.ExtractFunction = fakeExtractFn
		args[argName] = existing
		return m
	}
	args[argName] = modulecapabilities.GraphQLArgument{ExtractFunction: fakeExtractFn}
	return m
}
// Arguments returns the module's argument map itself (not a copy), so entries
// added later through the with* helpers are visible through the result.
func (m *dummyGraphQLModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
	args := m.arguments
	return args
}
// newGraphQLAdditionalModule builds a dummyAdditionalModule named name: a copy
// of a fresh dummyGraphQLModule plus an empty (non-nil) additional-property map.
func newGraphQLAdditionalModule(name string) *dummyAdditionalModule {
	inner := newGraphQLModule(name)
	return &dummyAdditionalModule{
		dummyGraphQLModule:   *inner,
		additionalProperties: make(map[string]modulecapabilities.AdditionalProperty),
	}
}
// dummyAdditionalModule is a test double that extends dummyGraphQLModule with
// configurable additional properties.
type dummyAdditionalModule struct {
	dummyGraphQLModule
	// additionalProperties maps a property name to its GraphQL/REST names and
	// default value; returned by AdditionalProperties.
	additionalProperties map[string]modulecapabilities.AdditionalProperty
}
// withArg delegates to the embedded dummyGraphQLModule.withArg. It returns the
// outer *dummyAdditionalModule (not the embedded value) so chained calls keep
// access to the additional-property helpers.
func (m *dummyAdditionalModule) withArg(argName string) *dummyAdditionalModule {
	_ = m.dummyGraphQLModule.withArg(argName)
	return m
}
// withExtractFn makes fakeExtractFn the ExtractFunction for argName, creating
// the entry if needed. Like withArg, it delegates to the embedded
// dummyGraphQLModule instead of duplicating its logic, and it returns the
// outer *dummyAdditionalModule so chained calls keep access to the
// additional-property helpers.
func (m *dummyAdditionalModule) withExtractFn(argName string) *dummyAdditionalModule {
	m.dummyGraphQLModule.withExtractFn(argName)
	return m
}
// withGraphQLArg appends values to the GraphQL names of additional property
// argName, creating the property if it does not exist yet, and returns m for
// chaining. GraphQLNames is always left non-nil, even when values is empty.
func (m *dummyAdditionalModule) withGraphQLArg(argName string, values []string) *dummyAdditionalModule {
	prop := m.additionalProperties[argName]
	names := prop.GraphQLNames
	if names == nil {
		names = make([]string, 0, len(values))
	}
	prop.GraphQLNames = append(names, values...)
	m.additionalProperties[argName] = prop
	return m
}
// withRestApiArg appends values to the REST names of additional property
// argName, creating the property if it does not exist yet, sets its
// DefaultValue to 100, and returns m for chaining. RestNames is always left
// non-nil, even when values is empty.
func (m *dummyAdditionalModule) withRestApiArg(argName string, values []string) *dummyAdditionalModule {
	prop := m.additionalProperties[argName]
	restNames := prop.RestNames
	if restNames == nil {
		restNames = make([]string, 0, len(values))
	}
	prop.RestNames = append(restNames, values...)
	// Every REST-exposed dummy property gets the same fixed default.
	prop.DefaultValue = 100
	m.additionalProperties[argName] = prop
	return m
}
// AdditionalProperties returns the module's property map itself (not a copy),
// reflecting everything configured via withGraphQLArg and withRestApiArg.
func (m *dummyAdditionalModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	props := m.additionalProperties
	return props
}
// getFakeSchemaGetter returns a schemaGetter backed by a fixed schema of three
// classes, ClassOne/ClassTwo/ClassThree, whose vectorizers are mod1/mod2/mod3
// respectively. Each class gets its own copy of the same "mod" module config.
func getFakeSchemaGetter() schemaGetter {
	newClass := func(className, vectorizer string) *models.Class {
		return &models.Class{
			Class:      className,
			Vectorizer: vectorizer,
			ModuleConfig: map[string]interface{}{
				"mod": map[string]interface{}{
					"some-config": "some-config-value",
				},
			},
		}
	}
	classes := []*models.Class{
		newClass("ClassOne", "mod1"),
		newClass("ClassTwo", "mod2"),
		newClass("ClassThree", "mod3"),
	}
	return &fakeSchemaGetter{
		schema: enitiesSchema.Schema{Objects: &models.Schema{Classes: classes}},
	}
}
// dummyBackupModuleWithAltNames is a stateless backup-module test double. It
// is registered under the name "SomeBackend" and can also be looked up as
// "AltBackendName" or "YetAnotherBackendName". Every storage method is a no-op
// that returns zero values and a nil error.
type dummyBackupModuleWithAltNames struct{}

// Name returns the module's primary name, "SomeBackend".
func (m *dummyBackupModuleWithAltNames) Name() string {
	return "SomeBackend"
}

// AltNames returns the additional names the module is registered under.
func (m *dummyBackupModuleWithAltNames) AltNames() []string {
	return []string{"AltBackendName", "YetAnotherBackendName"}
}

// Init does nothing and always succeeds.
func (m *dummyBackupModuleWithAltNames) Init(ctx context.Context, params moduletools.ModuleInitParams) error {
	return nil
}

// RootHandler returns nil: the dummy serves no HTTP routes.
func (m *dummyBackupModuleWithAltNames) RootHandler() http.Handler {
	return nil
}

// Type reports modulecapabilities.Backup.
func (m *dummyBackupModuleWithAltNames) Type() modulecapabilities.ModuleType {
	return modulecapabilities.Backup
}

// HomeDir returns an empty path for every backup ID.
func (m *dummyBackupModuleWithAltNames) HomeDir(backupID string) string {
	return ""
}

// GetObject returns no data and no error.
func (m *dummyBackupModuleWithAltNames) GetObject(ctx context.Context, backupID, key string) ([]byte, error) {
	return nil, nil
}

// WriteToFile does nothing and always succeeds.
func (m *dummyBackupModuleWithAltNames) WriteToFile(ctx context.Context, backupID, key, destPath string) error {
	return nil
}

// Write does not consume r; it reports 0 bytes written and no error.
func (m *dummyBackupModuleWithAltNames) Write(ctx context.Context, backupID, key string, r io.ReadCloser) (int64, error) {
	return 0, nil
}

// Read does not touch w; it reports 0 bytes read and no error.
func (m *dummyBackupModuleWithAltNames) Read(ctx context.Context, backupID, key string, w io.WriteCloser) (int64, error) {
	return 0, nil
}

// SourceDataPath returns an empty path.
func (m *dummyBackupModuleWithAltNames) SourceDataPath() string {
	return ""
}

// IsExternal always returns true.
func (m *dummyBackupModuleWithAltNames) IsExternal() bool {
	return true
}

// PutFile does nothing and always succeeds.
func (m *dummyBackupModuleWithAltNames) PutFile(ctx context.Context, backupID, key, srcPath string) error {
	return nil
}

// PutObject discards data and always succeeds.
func (m *dummyBackupModuleWithAltNames) PutObject(ctx context.Context, backupID, key string, data []byte) error {
	return nil
}

// Initialize does nothing and always succeeds.
func (m *dummyBackupModuleWithAltNames) Initialize(ctx context.Context, backupID string) error {
	return nil
}