KevinStephenson
Adding in weaviate code
b110593
raw
history blame
12.6 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package modcontextionary
import (
"context"
"fmt"
"net/http"
"github.com/sirupsen/logrus/hooks/test"
"github.com/tailor-inc/graphql"
"github.com/tailor-inc/graphql/language/ast"
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/explore"
"github.com/weaviate/weaviate/adapters/handlers/graphql/local/get"
test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper"
"github.com/weaviate/weaviate/entities/dto"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/search"
text2vecadditional "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional"
text2vecadditionalsempath "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/sempath"
text2vecadditionalprojector "github.com/weaviate/weaviate/usecases/modulecomponents/additional/projector"
text2vecneartext "github.com/weaviate/weaviate/usecases/modulecomponents/nearText"
"github.com/weaviate/weaviate/usecases/traverser"
)
// mockRequestsLog is a requests log that records nothing. The mock resolver
// constructors in this file place it under the "RequestsLog" key of the root
// object.
type mockRequestsLog struct{}

// Register accepts both strings and discards them; it has no side effects.
func (m *mockRequestsLog) Register(first, second string) {}
// mockResolver wraps the shared test_helper.MockResolver so that the
// resolver methods needed by this package's GraphQL tests can be attached
// to it.
type mockResolver struct {
	test_helper.MockResolver
}
// fakeInterpretation is a stateless stub that is handed to
// text2vecadditional.New as its last argument.
type fakeInterpretation struct{}

// AdditionalPropertyFn is a pass-through: the incoming results are returned
// unchanged together with a nil error. All other arguments are ignored.
func (f *fakeInterpretation) AdditionalPropertyFn(
	ctx context.Context,
	in []search.Result,
	params interface{},
	limit *int,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	return in, nil
}

// ExtractAdditionalFn disregards the supplied GraphQL arguments and always
// yields true.
func (f *fakeInterpretation) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	return true
}

// AdditionalPropertyDefaultValue always yields true.
func (f *fakeInterpretation) AdditionalPropertyDefaultValue() interface{} {
	return true
}
// fakeExtender is a stub that is handed to text2vecadditional.New as its
// first argument.
type fakeExtender struct {
	// returnArgs is what AdditionalPropertyFn hands back, regardless of input.
	returnArgs []search.Result
}

// AdditionalPropertyFn ignores every argument and returns the canned
// returnArgs with a nil error.
func (f *fakeExtender) AdditionalPropertyFn(
	ctx context.Context,
	in []search.Result,
	params interface{},
	limit *int,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	return f.returnArgs, nil
}

// ExtractAdditionalFn disregards the supplied GraphQL arguments and always
// yields true.
func (f *fakeExtender) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	return true
}

// AdditionalPropertyDefaultValue always yields true.
func (f *fakeExtender) AdditionalPropertyDefaultValue() interface{} {
	return true
}
// fakeProjector is a stub that is handed to text2vecadditional.New as its
// second argument.
type fakeProjector struct {
	// returnArgs is what AdditionalPropertyFn hands back, regardless of input.
	returnArgs []search.Result
}

// AdditionalPropertyFn ignores every argument and returns the canned
// returnArgs with a nil error.
func (f *fakeProjector) AdditionalPropertyFn(
	ctx context.Context,
	in []search.Result,
	params interface{},
	limit *int,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	return f.returnArgs, nil
}

// ExtractAdditionalFn builds projector params from the presence (not the
// content) of GraphQL arguments:
//   - no arguments: a Params value with only Enabled set to true;
//   - one or more arguments: a fresh Params value after
//     SetDefaultsAndValidate(100, 4), or nil if that call reports an error.
//
// NOTE(review): the meaning of the constants 100 and 4 is defined by
// SetDefaultsAndValidate, which is not visible here.
func (f *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	if len(param) == 0 {
		return &text2vecadditionalprojector.Params{Enabled: true}
	}

	projection := new(text2vecadditionalprojector.Params)
	if err := projection.SetDefaultsAndValidate(100, 4); err != nil {
		return nil
	}
	return projection
}

// AdditionalPropertyDefaultValue yields a pointer to a zero-value Params.
func (f *fakeProjector) AdditionalPropertyDefaultValue() interface{} {
	return new(text2vecadditionalprojector.Params)
}
// fakePathBuilder is a stub that is handed to text2vecadditional.New as its
// third argument.
type fakePathBuilder struct {
	// returnArgs is what AdditionalPropertyFn hands back, regardless of input.
	returnArgs []search.Result
}

// AdditionalPropertyFn ignores every argument and returns the canned
// returnArgs with a nil error.
func (f *fakePathBuilder) AdditionalPropertyFn(
	ctx context.Context,
	in []search.Result,
	params interface{},
	limit *int,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	return f.returnArgs, nil
}

// ExtractAdditionalFn disregards the supplied GraphQL arguments and yields a
// pointer to a zero-value sempath Params.
func (f *fakePathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	return new(text2vecadditionalsempath.Params)
}

// AdditionalPropertyDefaultValue yields a pointer to a zero-value sempath
// Params.
func (f *fakePathBuilder) AdditionalPropertyDefaultValue() interface{} {
	return new(text2vecadditionalsempath.Params)
}
// mockText2vecContextionaryModule is a stateless test double that reports
// itself as the "text2vec-contextionary" module and exposes that module's
// GraphQL arguments and additional properties.
type mockText2vecContextionaryModule struct{}

// Name returns the fixed module name.
func (m *mockText2vecContextionaryModule) Name() string {
	return "text2vec-contextionary"
}

// Init does nothing and always succeeds.
func (m *mockText2vecContextionaryModule) Init(params moduletools.ModuleInitParams) error {
	return nil
}

// RootHandler returns nil; this mock serves no HTTP routes.
func (m *mockText2vecContextionaryModule) RootHandler() http.Handler {
	return nil
}

// Arguments returns the GraphQL arguments produced by a nearText provider
// that is constructed with a nil argument.
func (m *mockText2vecContextionaryModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
	nearText := text2vecneartext.New(nil)
	return nearText.Arguments()
}

// AdditionalProperties returns the additional properties produced by a
// provider wired entirely with the zero-value fakes from this file.
func (m *mockText2vecContextionaryModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	provider := text2vecadditional.New(
		&fakeExtender{},
		&fakeProjector{},
		&fakePathBuilder{},
		&fakeInterpretation{},
	)
	return provider.AdditionalProperties()
}
// fakeModulesProvider is a stateless modules provider for tests. Its methods
// create a mockText2vecContextionaryModule on demand and delegate to it.
type fakeModulesProvider struct{}

// GetAll is not supported by this fake; calling it panics.
func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module {
	panic("implement me")
}

// VectorFromInput is not supported by this fake; calling it panics.
func (fmp *fakeModulesProvider) VectorFromInput(
	ctx context.Context, className string, input string,
) ([]float32, error) {
	panic("not implemented")
}
// GetArguments returns the Get-query argument configs for class. The result
// is an empty (non-nil) map unless the class's vectorizer equals the mock
// module's name, in which case it holds one entry per module argument,
// built by that argument's GetArgumentsFunction for the class name.
func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
	module := &mockText2vecContextionaryModule{}
	configs := map[string]*graphql.ArgumentConfig{}

	if class.Vectorizer != module.Name() {
		return configs
	}

	for argName, arg := range module.Arguments() {
		configs[argName] = arg.GetArgumentsFunction(class.Class)
	}
	return configs
}
// ExploreArguments returns the Explore-query argument configs for schema.
// For every class whose vectorizer equals the mock module's name, each
// module argument's ExploreArgumentsFunction result is stored under the
// argument name; later classes overwrite earlier entries with the same key.
// The result is an empty (non-nil) map when no class matches.
func (fmp *fakeModulesProvider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig {
	module := &mockText2vecContextionaryModule{}
	configs := map[string]*graphql.ArgumentConfig{}

	for _, class := range schema.Classes {
		if class.Vectorizer != module.Name() {
			continue
		}
		for argName, arg := range module.Arguments() {
			configs[argName] = arg.ExploreArgumentsFunction()
		}
	}
	return configs
}
// CrossClassExtractSearchParams is ExtractSearchParams called with an empty
// class name.
func (fmp *fakeModulesProvider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} {
	const noClassName = ""
	return fmp.ExtractSearchParams(arguments, noClassName)
}
// ExtractSearchParams pulls the module search parameters out of the raw
// GraphQL arguments. Only the "nearText" key is recognised: when present,
// its value is run through extractNearTextParam and stored under the same
// key. className is not used. The result is never nil.
//
// The "nearText" value must be a map[string]interface{}; any other dynamic
// type makes the type assertion panic.
func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} {
	extracted := map[string]interface{}{}

	rawNearText, present := arguments["nearText"]
	if !present {
		return extracted
	}

	extracted["nearText"] = extractNearTextParam(rawNearText.(map[string]interface{}))
	return extracted
}
// GetAdditionalFields returns one GraphQL field per additional property of
// the mock module, keyed by property name and built for the class's name.
// Properties that have no GraphQLFieldFunction are left out.
func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field {
	module := &mockText2vecContextionaryModule{}
	fields := map[string]*graphql.Field{}

	for propName, prop := range module.AdditionalProperties() {
		if prop.GraphQLFieldFunction == nil {
			continue
		}
		fields[propName] = prop.GraphQLFieldFunction(class.Class)
	}
	return fields
}
// ExtractAdditionalField runs the GraphQL extract function of the additional
// property called name on params and returns its result. It returns nil when
// the mock module has no such property or the property has no extract
// function. className is not used.
func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} {
	module := &mockText2vecContextionaryModule{}

	// Looking up a key in an empty or nil map simply reports "not found", so
	// no separate emptiness check is needed.
	prop, found := module.AdditionalProperties()[name]
	if !found || prop.GraphQLExtractFunction == nil {
		return nil
	}
	return prop.GraphQLExtractFunction(params)
}
// GetExploreAdditionalExtend extends the given results via additionalExtend
// using the "ExploreGet" capability.
//
// Note that cfg is not forwarded: a nil class config is passed on instead.
func (fmp *fakeModulesProvider) GetExploreAdditionalExtend(
	ctx context.Context,
	in []search.Result,
	moduleParams map[string]interface{},
	searchVector []float32,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) ([]search.Result, error) {
	const capability = "ExploreGet"
	return fmp.additionalExtend(
		ctx, in, moduleParams, searchVector, capability, argumentModuleParams, nil,
	)
}
// additionalExtend feeds the results through the additional-property
// functions selected by moduleParams.
//
// For each (name, value) pair in moduleParams, the property of that name is
// looked up on the mock module and its function for the given capability is
// resolved. Pairs with no such function or a nil value are skipped. If the
// value implements AdditionalPropertyWithSearchVector, searchVector is set
// on it before the call. The output of each call becomes the input of the
// next; the first error aborts processing and is returned with nil results.
//
// moduleParams is a map, so the order in which the properties are applied
// is unspecified. argumentModuleParams and cfg are accepted but not used:
// the property functions are always invoked with nil limit, nil argument
// params and nil class config.
func (fmp *fakeModulesProvider) additionalExtend(
	ctx context.Context,
	in search.Results,
	moduleParams map[string]interface{},
	searchVector []float32,
	capability string,
	argumentModuleParams map[string]interface{},
	cfg moduletools.ClassConfig,
) (search.Results, error) {
	module := &mockText2vecContextionaryModule{}
	props := module.AdditionalProperties()

	for propName, propParams := range moduleParams {
		extendFn := fmp.getAdditionalPropertyFn(props[propName], capability)
		if extendFn == nil || propParams == nil {
			continue
		}

		callParams := propParams
		if withVector, ok := propParams.(modulecapabilities.AdditionalPropertyWithSearchVector); ok {
			withVector.SetSearchVector(searchVector)
			callParams = withVector
		}

		extended, err := extendFn(ctx, in, callParams, nil, nil, nil)
		if err != nil {
			return nil, err
		}
		in = extended
	}

	return in, nil
}
// getAdditionalPropertyFn picks the search function of additionalProperty
// that corresponds to capability. Recognised capabilities are "ObjectGet",
// "ObjectList", "ExploreGet" and "ExploreList"; any other value yields nil.
func (fmp *fakeModulesProvider) getAdditionalPropertyFn(
	additionalProperty modulecapabilities.AdditionalProperty,
	capability string,
) modulecapabilities.AdditionalPropertyFn {
	// selected stays nil for an unrecognised capability.
	var selected modulecapabilities.AdditionalPropertyFn

	switch fns := additionalProperty.SearchFunctions; capability {
	case "ObjectGet":
		selected = fns.ObjectGet
	case "ObjectList":
		selected = fns.ObjectList
	case "ExploreGet":
		selected = fns.ExploreGet
	case "ExploreList":
		selected = fns.ExploreList
	}

	return selected
}
// GraphQLAdditionalFieldNames collects the GraphQL names of every additional
// property of the mock module into one slice. The result is never nil; since
// the properties come from a map, the order of the names is unspecified.
func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string {
	module := &mockText2vecContextionaryModule{}
	names := []string{}

	for _, prop := range module.AdditionalProperties() {
		// Appending a nil or empty GraphQLNames slice leaves names untouched,
		// so properties without names need no special handling.
		names = append(names, prop.GraphQLNames...)
	}
	return names
}
// extractNearTextParam applies the extract function of the mock module's
// "nearText" argument to param and returns the result.
func extractNearTextParam(param map[string]interface{}) interface{} {
	module := &mockText2vecContextionaryModule{}
	return module.Arguments()["nearText"].ExtractFunction(param)
}
// createArg builds a GraphQL AST argument called name whose value is the
// string value. Both the argument and its string value carry the placeholder
// kind "Kind".
func createArg(name string, value string) *ast.Argument {
	argName := ast.NewName(&ast.Name{Value: name})
	argValue := ast.NewStringValue(&ast.StringValue{
		Kind:  "Kind",
		Value: value,
	})

	return ast.NewArgument(&ast.Argument{
		Name:  argName,
		Kind:  "Kind",
		Value: argValue,
	})
}
// extractAdditionalParam runs the GraphQL extract function of the mock
// module's additional property called name on args and returns its result.
//
// Only "semanticPath" and "featureProjection" are supported; every other
// name yields nil. nil is also returned when the module does not expose the
// property or the property has no extract function. That last check matches
// fakeModulesProvider.ExtractAdditionalField and avoids calling a nil
// function value, which would panic.
func extractAdditionalParam(name string, args []*ast.Argument) interface{} {
	switch name {
	case "semanticPath", "featureProjection":
		// The property map is only built for names that actually need it.
		txt2vec := &mockText2vecContextionaryModule{}
		ap, ok := txt2vec.AdditionalProperties()[name]
		if !ok || ap.GraphQLExtractFunction == nil {
			return nil
		}
		return ap.GraphQLExtractFunction(args)
	default:
		return nil
	}
}
// getFakeModulesProvider returns a fresh fakeModulesProvider.
func getFakeModulesProvider() *fakeModulesProvider {
	return new(fakeModulesProvider)
}
// newMockResolver builds a mockResolver for "Get" queries. The root field is
// built from test_helper.SimpleSchema with a null logger and the fake modules
// provider; the root object exposes the resolver itself under "Resolver" and
// a no-op requests log under "RequestsLog".
//
// It panics if the GraphQL schema cannot be built.
func newMockResolver() *mockResolver {
	// Only the logger is needed; the hook returned alongside it is discarded.
	logger, _ := test.NewNullLogger()

	rootField, err := get.Build(&test_helper.SimpleSchema, logger, getFakeModulesProvider())
	if err != nil {
		panic(fmt.Sprintf("could not build graphql test schema: %s", err))
	}

	resolver := &mockResolver{}
	resolver.RootFieldName = "Get"
	resolver.RootField = rootField
	resolver.RootObject = map[string]interface{}{
		"Resolver":    GetResolver(resolver),
		"RequestsLog": RequestsLog(&mockRequestsLog{}),
	}
	return resolver
}
// newExploreMockResolver builds a mockResolver for "Explore" queries. The
// root field is built from the objects of test_helper.SimpleSchema and the
// fake modules provider; the root object exposes the resolver itself under
// "Resolver" and a no-op requests log under "RequestsLog".
func newExploreMockResolver() *mockResolver {
	rootField := explore.Build(test_helper.SimpleSchema.Objects, getFakeModulesProvider())

	resolver := &mockResolver{}
	resolver.RootFieldName = "Explore"
	resolver.RootField = rootField

	rootObject := map[string]interface{}{}
	rootObject["Resolver"] = ExploreResolver(resolver)
	rootObject["RequestsLog"] = &mockRequestsLog{}
	resolver.RootObject = rootObject

	return resolver
}
// GetClass records the call on the embedded mock, keyed by params only (ctx
// and principal are not passed to the mock), and returns the values
// configured for it: return value 0 as []interface{} and return value 1 as
// the error. The type assertion panics if value 0 has any other type.
func (m *mockResolver) GetClass(
	ctx context.Context,
	principal *models.Principal,
	params dto.GetParams,
) ([]interface{}, error) {
	call := m.Called(params)
	classes := call.Get(0).([]interface{})
	return classes, call.Error(1)
}
// Explore records the call on the embedded mock, keyed by params only (ctx
// and principal are not passed to the mock), and returns the values
// configured for it: return value 0 as []search.Result and return value 1 as
// the error. The type assertion panics if value 0 has any other type.
func (m *mockResolver) Explore(
	ctx context.Context,
	principal *models.Principal,
	params traverser.ExploreParams,
) ([]search.Result, error) {
	call := m.Called(params)
	results := call.Get(0).([]search.Result)
	return results, call.Error(1)
}
// GetResolver is a local abstraction of the resolver required to serve Get
// queries: it fetches class data for the given params on behalf of principal.
type GetResolver interface {
	GetClass(
		ctx context.Context,
		principal *models.Principal,
		info dto.GetParams,
	) ([]interface{}, error)
}
// ExploreResolver is a local abstraction of the resolver required to serve
// Explore queries: it returns search results for the given params on behalf
// of principal.
type ExploreResolver interface {
	Explore(
		ctx context.Context,
		principal *models.Principal,
		params traverser.ExploreParams,
	) ([]search.Result, error)
}
// RequestsLog is a local abstraction of the requests log which has to be
// handed to the GraphQL API so that Local.Get queries can be logged.
type RequestsLog interface {
	// Register records a request of the given type for the given identifier.
	Register(requestType, identifier string)
}