Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package modcontextionary | |
import ( | |
"context" | |
"fmt" | |
"net/http" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/tailor-inc/graphql" | |
"github.com/tailor-inc/graphql/language/ast" | |
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/explore" | |
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/get" | |
test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper" | |
"github.com/weaviate/weaviate/entities/dto" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/entities/search" | |
text2vecadditional "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional" | |
text2vecadditionalsempath "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/sempath" | |
text2vecadditionalprojector "github.com/weaviate/weaviate/usecases/modulecomponents/additional/projector" | |
text2vecneartext "github.com/weaviate/weaviate/usecases/modulecomponents/nearText" | |
"github.com/weaviate/weaviate/usecases/traverser" | |
) | |
// mockRequestsLog is a no-op requests log: Register accepts any pair of
// strings and does nothing with them.
type mockRequestsLog struct{}

// Register discards both of its arguments.
func (m *mockRequestsLog) Register(_, _ string) {}
type mockResolver struct { | |
test_helper.MockResolver | |
} | |
type fakeInterpretation struct{} | |
func (f *fakeInterpretation) AdditionalPropertyFn(ctx context.Context, | |
in []search.Result, params interface{}, limit *int, | |
argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) ([]search.Result, error) { | |
return in, nil | |
} | |
func (f *fakeInterpretation) ExtractAdditionalFn(param []*ast.Argument) interface{} { | |
return true | |
} | |
func (f *fakeInterpretation) AdditionalPropertyDefaultValue() interface{} { | |
return true | |
} | |
type fakeExtender struct { | |
returnArgs []search.Result | |
} | |
func (f *fakeExtender) AdditionalPropertyFn(ctx context.Context, | |
in []search.Result, params interface{}, limit *int, | |
argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) ([]search.Result, error) { | |
return f.returnArgs, nil | |
} | |
func (f *fakeExtender) ExtractAdditionalFn(param []*ast.Argument) interface{} { | |
return true | |
} | |
func (f *fakeExtender) AdditionalPropertyDefaultValue() interface{} { | |
return true | |
} | |
type fakeProjector struct { | |
returnArgs []search.Result | |
} | |
func (f *fakeProjector) AdditionalPropertyFn(ctx context.Context, | |
in []search.Result, params interface{}, limit *int, | |
argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) ([]search.Result, error) { | |
return f.returnArgs, nil | |
} | |
func (f *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} { | |
if len(param) > 0 { | |
p := &text2vecadditionalprojector.Params{} | |
err := p.SetDefaultsAndValidate(100, 4) | |
if err != nil { | |
return nil | |
} | |
return p | |
} | |
return &text2vecadditionalprojector.Params{ | |
Enabled: true, | |
} | |
} | |
func (f *fakeProjector) AdditionalPropertyDefaultValue() interface{} { | |
return &text2vecadditionalprojector.Params{} | |
} | |
type fakePathBuilder struct { | |
returnArgs []search.Result | |
} | |
func (f *fakePathBuilder) AdditionalPropertyFn(ctx context.Context, | |
in []search.Result, params interface{}, limit *int, | |
argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) ([]search.Result, error) { | |
return f.returnArgs, nil | |
} | |
func (f *fakePathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} { | |
return &text2vecadditionalsempath.Params{} | |
} | |
func (f *fakePathBuilder) AdditionalPropertyDefaultValue() interface{} { | |
return &text2vecadditionalsempath.Params{} | |
} | |
type mockText2vecContextionaryModule struct{} | |
func (m *mockText2vecContextionaryModule) Name() string { | |
return "text2vec-contextionary" | |
} | |
func (m *mockText2vecContextionaryModule) Init(params moduletools.ModuleInitParams) error { | |
return nil | |
} | |
func (m *mockText2vecContextionaryModule) RootHandler() http.Handler { | |
return nil | |
} | |
// graphql arguments | |
func (m *mockText2vecContextionaryModule) Arguments() map[string]modulecapabilities.GraphQLArgument { | |
return text2vecneartext.New(nil).Arguments() | |
} | |
// additional properties | |
func (m *mockText2vecContextionaryModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { | |
return text2vecadditional.New(&fakeExtender{}, &fakeProjector{}, &fakePathBuilder{}, &fakeInterpretation{}).AdditionalProperties() | |
} | |
type fakeModulesProvider struct{} | |
func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module { | |
panic("implement me") | |
} | |
func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) { | |
panic("not implemented") | |
} | |
func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig { | |
args := map[string]*graphql.ArgumentConfig{} | |
txt2vec := &mockText2vecContextionaryModule{} | |
if class.Vectorizer == txt2vec.Name() { | |
for name, argument := range txt2vec.Arguments() { | |
args[name] = argument.GetArgumentsFunction(class.Class) | |
} | |
} | |
return args | |
} | |
func (fmp *fakeModulesProvider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig { | |
args := map[string]*graphql.ArgumentConfig{} | |
txt2vec := &mockText2vecContextionaryModule{} | |
for _, c := range schema.Classes { | |
if c.Vectorizer == txt2vec.Name() { | |
for name, argument := range txt2vec.Arguments() { | |
args[name] = argument.ExploreArgumentsFunction() | |
} | |
} | |
} | |
return args | |
} | |
func (fmp *fakeModulesProvider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} { | |
return fmp.ExtractSearchParams(arguments, "") | |
} | |
func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} { | |
exractedParams := map[string]interface{}{} | |
if param, ok := arguments["nearText"]; ok { | |
exractedParams["nearText"] = extractNearTextParam(param.(map[string]interface{})) | |
} | |
return exractedParams | |
} | |
func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field { | |
txt2vec := &mockText2vecContextionaryModule{} | |
additionalProperties := map[string]*graphql.Field{} | |
for name, additionalProperty := range txt2vec.AdditionalProperties() { | |
if additionalProperty.GraphQLFieldFunction != nil { | |
additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class) | |
} | |
} | |
return additionalProperties | |
} | |
func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} { | |
txt2vec := &mockText2vecContextionaryModule{} | |
if additionalProperties := txt2vec.AdditionalProperties(); len(additionalProperties) > 0 { | |
if additionalProperty, ok := additionalProperties[name]; ok { | |
if additionalProperty.GraphQLExtractFunction != nil { | |
return additionalProperty.GraphQLExtractFunction(params) | |
} | |
} | |
} | |
return nil | |
} | |
func (fmp *fakeModulesProvider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result, | |
moduleParams map[string]interface{}, searchVector []float32, | |
argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) ([]search.Result, error) { | |
return fmp.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams, nil) | |
} | |
func (fmp *fakeModulesProvider) additionalExtend(ctx context.Context, | |
in search.Results, moduleParams map[string]interface{}, | |
searchVector []float32, capability string, argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, | |
) (search.Results, error) { | |
txt2vec := &mockText2vecContextionaryModule{} | |
additionalProperties := txt2vec.AdditionalProperties() | |
for name, value := range moduleParams { | |
additionalPropertyFn := fmp.getAdditionalPropertyFn(additionalProperties[name], capability) | |
if additionalPropertyFn != nil && value != nil { | |
searchValue := value | |
if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok { | |
searchVectorValue.SetSearchVector(searchVector) | |
searchValue = searchVectorValue | |
} | |
resArray, err := additionalPropertyFn(ctx, in, searchValue, nil, nil, nil) | |
if err != nil { | |
return nil, err | |
} | |
in = resArray | |
} | |
} | |
return in, nil | |
} | |
func (fmp *fakeModulesProvider) getAdditionalPropertyFn(additionalProperty modulecapabilities.AdditionalProperty, | |
capability string, | |
) modulecapabilities.AdditionalPropertyFn { | |
switch capability { | |
case "ObjectGet": | |
return additionalProperty.SearchFunctions.ObjectGet | |
case "ObjectList": | |
return additionalProperty.SearchFunctions.ObjectList | |
case "ExploreGet": | |
return additionalProperty.SearchFunctions.ExploreGet | |
case "ExploreList": | |
return additionalProperty.SearchFunctions.ExploreList | |
default: | |
return nil | |
} | |
} | |
func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string { | |
txt2vec := &mockText2vecContextionaryModule{} | |
additionalPropertiesNames := []string{} | |
for _, additionalProperty := range txt2vec.AdditionalProperties() { | |
if additionalProperty.GraphQLNames != nil { | |
additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...) | |
} | |
} | |
return additionalPropertiesNames | |
} | |
func extractNearTextParam(param map[string]interface{}) interface{} { | |
txt2vec := &mockText2vecContextionaryModule{} | |
argument := txt2vec.Arguments()["nearText"] | |
return argument.ExtractFunction(param) | |
} | |
func createArg(name string, value string) *ast.Argument { | |
n := ast.Name{ | |
Value: name, | |
} | |
val := ast.StringValue{ | |
Kind: "Kind", | |
Value: value, | |
} | |
arg := ast.Argument{ | |
Name: ast.NewName(&n), | |
Kind: "Kind", | |
Value: ast.NewStringValue(&val), | |
} | |
a := ast.NewArgument(&arg) | |
return a | |
} | |
func extractAdditionalParam(name string, args []*ast.Argument) interface{} { | |
txt2vec := &mockText2vecContextionaryModule{} | |
additionalProperties := txt2vec.AdditionalProperties() | |
switch name { | |
case "semanticPath", "featureProjection": | |
if ap, ok := additionalProperties[name]; ok { | |
return ap.GraphQLExtractFunction(args) | |
} | |
return nil | |
default: | |
return nil | |
} | |
} | |
func getFakeModulesProvider() *fakeModulesProvider { | |
return &fakeModulesProvider{} | |
} | |
func newMockResolver() *mockResolver { | |
logger, _ := test.NewNullLogger() | |
field, err := get.Build(&test_helper.SimpleSchema, logger, getFakeModulesProvider()) | |
if err != nil { | |
panic(fmt.Sprintf("could not build graphql test schema: %s", err)) | |
} | |
mocker := &mockResolver{} | |
mockLog := &mockRequestsLog{} | |
mocker.RootFieldName = "Get" | |
mocker.RootField = field | |
mocker.RootObject = map[string]interface{}{"Resolver": GetResolver(mocker), "RequestsLog": RequestsLog(mockLog)} | |
return mocker | |
} | |
func newExploreMockResolver() *mockResolver { | |
field := explore.Build(test_helper.SimpleSchema.Objects, getFakeModulesProvider()) | |
mocker := &mockResolver{} | |
mockLog := &mockRequestsLog{} | |
mocker.RootFieldName = "Explore" | |
mocker.RootField = field | |
mocker.RootObject = map[string]interface{}{ | |
"Resolver": ExploreResolver(mocker), | |
"RequestsLog": mockLog, | |
} | |
return mocker | |
} | |
func (m *mockResolver) GetClass(ctx context.Context, principal *models.Principal, | |
params dto.GetParams, | |
) ([]interface{}, error) { | |
args := m.Called(params) | |
return args.Get(0).([]interface{}), args.Error(1) | |
} | |
func (m *mockResolver) Explore(ctx context.Context, | |
principal *models.Principal, params traverser.ExploreParams, | |
) ([]search.Result, error) { | |
args := m.Called(params) | |
return args.Get(0).([]search.Result), args.Error(1) | |
} | |
// Resolver is a local abstraction of the required UC resolvers | |
type GetResolver interface { | |
GetClass(ctx context.Context, principal *models.Principal, info dto.GetParams) ([]interface{}, error) | |
} | |
type ExploreResolver interface { | |
Explore(ctx context.Context, principal *models.Principal, params traverser.ExploreParams) ([]search.Result, error) | |
} | |
// RequestsLog is a local abstraction on the RequestsLog that needs to be | |
// provided to the graphQL API in order to log Local.Get queries. | |
type RequestsLog interface { | |
Register(requestType string, identifier string) | |
} | |