KevinStephenson
Adding in weaviate code
b110593
raw
history blame
20.8 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package get
import (
"context"
"fmt"
"net/http"
"github.com/sirupsen/logrus/hooks/test"
"github.com/tailor-inc/graphql"
"github.com/tailor-inc/graphql/language/ast"
"github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions"
test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper"
"github.com/weaviate/weaviate/entities/dto"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/search"
"github.com/weaviate/weaviate/usecases/config"
)
// mockRequestsLog is a no-op RequestsLog: Register accepts its two strings and
// discards them.
type mockRequestsLog struct{}

// Register ignores both arguments.
func (l *mockRequestsLog) Register(first, second string) {}
// mockResolver is the Get resolver used by these tests. It embeds
// test_helper.MockResolver (which supplies the root field/object plumbing and
// the Called mock hook) and adds the GetClass method defined in this file.
type mockResolver struct {
	test_helper.MockResolver
}
// fakeInterpretation stubs the "interpretation" additional property. It ignores
// all of its inputs: AdditionalPropertyFn hands back the preconfigured
// returnArgs, and both the extract and default-value hooks report true.
type fakeInterpretation struct {
	returnArgs []search.Result
}

// AdditionalPropertyFn returns the canned results and never fails.
func (fi *fakeInterpretation) AdditionalPropertyFn(_ context.Context,
	_ []search.Result, _ interface{}, _ *int,
	_ map[string]interface{},
) ([]search.Result, error) {
	canned := fi.returnArgs
	return canned, nil
}

// ExtractAdditionalFn ignores the GraphQL arguments and always yields true.
func (fi *fakeInterpretation) ExtractAdditionalFn(_ []*ast.Argument) interface{} {
	const enabled = true
	return enabled
}

// AdditonalPropertyDefaultValue (spelling matches the interface) yields true.
func (fi *fakeInterpretation) AdditonalPropertyDefaultValue() interface{} {
	const enabled = true
	return enabled
}
// fakeExtender stubs the "nearestNeighbors" additional property. Its inputs are
// ignored: AdditionalPropertyFn returns the preconfigured returnArgs, and the
// extract and default-value hooks both report true.
type fakeExtender struct {
	returnArgs []search.Result
}

// AdditionalPropertyFn returns the canned results and never fails.
func (fe *fakeExtender) AdditionalPropertyFn(_ context.Context,
	_ []search.Result, _ interface{}, _ *int,
	_ map[string]interface{},
) ([]search.Result, error) {
	canned := fe.returnArgs
	return canned, nil
}

// ExtractAdditionalFn ignores the GraphQL arguments and always yields true.
func (fe *fakeExtender) ExtractAdditionalFn(_ []*ast.Argument) interface{} {
	const enabled = true
	return enabled
}

// AdditonalPropertyDefaultValue (spelling matches the interface) yields true.
func (fe *fakeExtender) AdditonalPropertyDefaultValue() interface{} {
	const enabled = true
	return enabled
}
// fakeProjectorParams is the parameter value produced by fakeProjector's
// extract and default-value hooks for the "featureProjection" property.
// Field order is significant for any unkeyed literals and is kept stable.
type fakeProjectorParams struct {
	Enabled   bool
	Algorithm string

	// Numeric tuning knobs, in declaration order:
	// Dimensions, Perplexity, Iterations, LearningRate.
	Dimensions, Perplexity, Iterations, LearningRate int

	IncludeNeighbors bool
}
// fakeProjector stubs the "featureProjection" additional property.
type fakeProjector struct {
	returnArgs []search.Result
}

// AdditionalPropertyFn returns the canned results and never fails.
func (fp *fakeProjector) AdditionalPropertyFn(_ context.Context,
	_ []search.Result, _ interface{}, _ *int,
	_ map[string]interface{},
) ([]search.Result, error) {
	canned := fp.returnArgs
	return canned, nil
}

// ExtractAdditionalFn returns enabled params. When no arguments are given,
// only Enabled is set. When at least one argument is present (regardless of
// its name or value), a fixed "tsne" configuration is returned.
func (fp *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} {
	if len(param) == 0 {
		return &fakeProjectorParams{Enabled: true}
	}
	return &fakeProjectorParams{
		Enabled:      true,
		Algorithm:    "tsne",
		Dimensions:   3,
		Perplexity:   10,
		Iterations:   100,
		LearningRate: 15,
	}
}

// AdditonalPropertyDefaultValue (spelling matches the interface) returns
// zero-valued params, i.e. the projection is disabled by default.
func (fp *fakeProjector) AdditonalPropertyDefaultValue() interface{} {
	return new(fakeProjectorParams)
}
// pathBuilderParams is the empty parameter value used for "semanticPath".
type pathBuilderParams struct{}

// fakePathBuilder stubs the "semanticPath" additional property.
type fakePathBuilder struct {
	returnArgs []search.Result
}

// AdditionalPropertyFn returns the canned results and never fails.
// NOTE(review): unlike the other fakes in this file, this signature has no
// argumentModuleParams parameter. It is never called through an interface here;
// confirm whether it should match its siblings.
func (pb *fakePathBuilder) AdditionalPropertyFn(_ context.Context,
	_ []search.Result, _ interface{}, _ *int,
) ([]search.Result, error) {
	canned := pb.returnArgs
	return canned, nil
}

// ExtractAdditionalFn ignores the GraphQL arguments and returns empty params.
func (pb *fakePathBuilder) ExtractAdditionalFn(_ []*ast.Argument) interface{} {
	return new(pathBuilderParams)
}

// AdditonalPropertyDefaultValue (spelling matches the interface) returns empty params.
func (pb *fakePathBuilder) AdditonalPropertyDefaultValue() interface{} {
	return new(pathBuilderParams)
}
// nearCustomTextParams holds the parsed nearCustomText search argument.
type nearCustomTextParams struct {
	Values       []string
	MoveTo       nearExploreMove
	MoveAwayFrom nearExploreMove
	Certainty    float64
	Distance     float64
	WithDistance bool
}

// The following methods implement the modulecapabilities.NearParam interface.
// All of them use value receivers, so both nearCustomTextParams and
// *nearCustomTextParams carry the complete method set.

// GetCertainty returns the requested certainty (zero when none was supplied).
func (n nearCustomTextParams) GetCertainty() float64 {
	return n.Certainty
}

// GetDistance returns the requested distance. It is only meaningful when
// WithDistance is set.
func (n nearCustomTextParams) GetDistance() float64 {
	return n.Distance
}

// SimilarityMetricProvided reports whether a similarity bound was given. A
// certainty counts when it is non-zero; a distance counts whenever WithDistance
// is set (so an explicit distance of 0 still counts).
func (n nearCustomTextParams) SimilarityMetricProvided() bool {
	return n.Certainty != 0 || n.WithDistance
}

// nearExploreMove describes a moveTo / moveAwayFrom vector movement: concepts,
// a force, and optional objects to move relative to.
type nearExploreMove struct {
	Values  []string
	Force   float32
	Objects []nearObjectMove
}

// nearObjectMove references a movement target object by ID and/or beacon.
type nearObjectMove struct {
	ID     string
	Beacon string
}
// nearCustomTextModule is a fake vectorizer module exposing a nearCustomText
// argument plus four additional properties, each backed by a fake provider.
type nearCustomTextModule struct {
	fakePathBuilder    *fakePathBuilder
	fakeProjector      *fakeProjector
	fakeExtender       *fakeExtender
	fakeInterpretation *fakeInterpretation
}

// newNearCustomTextModule returns a module whose fakes all start with no
// canned results.
func newNearCustomTextModule() *nearCustomTextModule {
	module := new(nearCustomTextModule)
	module.fakePathBuilder = &fakePathBuilder{}
	module.fakeProjector = &fakeProjector{}
	module.fakeExtender = &fakeExtender{}
	module.fakeInterpretation = &fakeInterpretation{}
	return module
}
// Name returns the module name. GetArguments compares it to a class's
// Vectorizer to decide whether to expose nearCustomText.
func (m *nearCustomTextModule) Name() string {
	const moduleName = "mock-custom-near-text-module"
	return moduleName
}

// Init is a no-op for the fake module.
func (m *nearCustomTextModule) Init(_ moduletools.ModuleInitParams) error {
	return nil
}

// RootHandler returns nil: the fake module serves no HTTP endpoints.
func (m *nearCustomTextModule) RootHandler() http.Handler {
	var handler http.Handler
	return handler
}
// getNearCustomTextArgument builds the GraphQL definition of the nearCustomText
// argument for the given class. Every generated input-object type name is
// prefixed with classname, so each class gets its own set of types.
func (m *nearCustomTextModule) getNearCustomTextArgument(classname string) *graphql.ArgumentConfig {
	// movement builds an optional moveTo / moveAwayFrom input field. Both
	// variants share one shape and differ only in their generated type names.
	movement := func(movementTypeName, objectsTypeName string) *graphql.InputObjectFieldConfig {
		movementObject := graphql.NewInputObject(graphql.InputObjectConfig{
			Name: objectsTypeName,
			Fields: graphql.InputObjectConfigFieldMap{
				"id": &graphql.InputObjectFieldConfig{
					Type:        graphql.String,
					Description: "id of an object",
				},
				"beacon": &graphql.InputObjectFieldConfig{
					Type:        graphql.String,
					Description: descriptions.Beacon,
				},
			},
			Description: "Movement Object",
		})
		return &graphql.InputObjectFieldConfig{
			Description: descriptions.VectorMovement,
			Type: graphql.NewInputObject(graphql.InputObjectConfig{
				Name: movementTypeName,
				Fields: graphql.InputObjectConfigFieldMap{
					"concepts": &graphql.InputObjectFieldConfig{
						Description: descriptions.Keywords,
						Type:        graphql.NewList(graphql.String),
					},
					"objects": &graphql.InputObjectFieldConfig{
						Description: "objects",
						Type:        graphql.NewList(movementObject),
					},
					"force": &graphql.InputObjectFieldConfig{
						Description: descriptions.Force,
						Type:        graphql.NewNonNull(graphql.Float),
					},
				},
			}),
		}
	}

	return &graphql.ArgumentConfig{
		Type: graphql.NewInputObject(graphql.InputObjectConfig{
			Name: fmt.Sprintf("%sNearCustomTextInpObj", classname),
			Fields: graphql.InputObjectConfigFieldMap{
				"concepts": &graphql.InputObjectFieldConfig{
					Type: graphql.NewNonNull(graphql.NewList(graphql.String)),
				},
				"moveTo": movement(
					fmt.Sprintf("%sMoveTo", classname),
					fmt.Sprintf("%sMovementObjectsToInpObj", classname),
				),
				"moveAwayFrom": movement(
					fmt.Sprintf("%sMoveAway", classname),
					fmt.Sprintf("%sMovementObjectsAwayInpObj", classname),
				),
				"certainty": &graphql.InputObjectFieldConfig{
					Description: descriptions.Certainty,
					Type:        graphql.Float,
				},
				"distance": &graphql.InputObjectFieldConfig{
					Description: descriptions.Distance,
					Type:        graphql.Float,
				},
			},
			Description: descriptions.GetWhereInpObj,
		}),
	}
}
// extractNearCustomTextArgument converts the raw nearCustomText argument map
// into nearCustomTextParams.
//
// "concepts" is declared non-null in the schema and is asserted directly; a
// missing or mistyped value panics, because that indicates a broken schema.
// The optional keys (certainty, distance, moveTo, moveAwayFrom) are applied only
// when they are present with a value of the expected type. An explicit null is
// therefore treated like an absent key instead of panicking on the type
// assertion.
func (m *nearCustomTextModule) extractNearCustomTextArgument(source map[string]interface{}) *nearCustomTextParams {
	var args nearCustomTextParams

	concepts := source["concepts"].([]interface{})
	args.Values = make([]string, len(concepts))
	for i, value := range concepts {
		args.Values[i] = value.(string)
	}

	if certainty, ok := source["certainty"].(float64); ok {
		args.Certainty = certainty
	}
	if distance, ok := source["distance"].(float64); ok {
		args.Distance = distance
		args.WithDistance = true
	}

	// moveTo / moveAwayFrom are optional: they may be absent or explicitly null.
	if moveTo, ok := source["moveTo"].(map[string]interface{}); ok {
		args.MoveTo = m.parseMoveParam(moveTo)
	}
	if moveAwayFrom, ok := source["moveAwayFrom"].(map[string]interface{}); ok {
		args.MoveAwayFrom = m.parseMoveParam(moveAwayFrom)
	}

	return &args
}
// parseMoveParam converts a raw moveTo / moveAwayFrom map into a nearExploreMove.
// "force" must be present as a float64, otherwise this panics. Concepts and
// objects are read only when they are lists. Each object entry that is a map
// copies its non-nil "id" and "beacon" values, which must be strings.
func (m *nearCustomTextModule) parseMoveParam(source map[string]interface{}) nearExploreMove {
	move := nearExploreMove{
		Force: float32(source["force"].(float64)),
	}

	if rawConcepts, isList := source["concepts"].([]interface{}); isList {
		move.Values = make([]string, 0, len(rawConcepts))
		for _, concept := range rawConcepts {
			move.Values = append(move.Values, concept.(string))
		}
	}

	rawObjects, isList := source["objects"].([]interface{})
	if !isList {
		return move
	}
	move.Objects = make([]nearObjectMove, len(rawObjects))
	for i := range rawObjects {
		entry, isMap := rawObjects[i].(map[string]interface{})
		if !isMap {
			// Non-map entries leave a zero-valued nearObjectMove in place.
			continue
		}
		if id := entry["id"]; id != nil {
			move.Objects[i].ID = id.(string)
		}
		if beacon := entry["beacon"]; beacon != nil {
			move.Objects[i].Beacon = beacon.(string)
		}
	}
	return move
}
// Arguments exposes the single nearCustomText search argument. Its definition
// and extraction delegate to this module; validation accepts everything.
func (m *nearCustomTextModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
	nearCustomText := modulecapabilities.GraphQLArgument{
		GetArgumentsFunction: m.getNearCustomTextArgument,
		ExtractFunction: func(source map[string]interface{}) interface{} {
			return m.extractNearCustomTextArgument(source)
		},
		// Every parameter is considered valid by the fake.
		ValidateFunction: func(interface{}) error { return nil },
	}
	return map[string]modulecapabilities.GraphQLArgument{
		"nearCustomText": nearCustomText,
	}
}
// AdditionalProperties returns the module's additional properties, keyed by
// their GraphQL field name.
func (m *nearCustomTextModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	return map[string]modulecapabilities.AdditionalProperty{
		"featureProjection": m.getFeatureProjection(),
		"nearestNeighbors":  m.getNearestNeighbors(),
		"semanticPath":      m.getSemanticPath(),
		"interpretation":    m.getInterpretation(),
	}
}
// getFeatureProjection describes the "featureProjection" additional property.
// It has optional algorithm/dimensions/learningRate/iterations/perplexity
// arguments and an object type holding a "vector" list of floats.
func (m *nearCustomTextModule) getFeatureProjection() modulecapabilities.AdditionalProperty {
	// optionalInt declares an Int argument with no default value.
	optionalInt := func() *graphql.ArgumentConfig {
		return &graphql.ArgumentConfig{Type: graphql.Int, DefaultValue: nil}
	}

	fieldFor := func(classname string) *graphql.Field {
		return &graphql.Field{
			Args: graphql.FieldConfigArgument{
				"algorithm":    &graphql.ArgumentConfig{Type: graphql.String, DefaultValue: nil},
				"dimensions":   optionalInt(),
				"learningRate": optionalInt(),
				"iterations":   optionalInt(),
				"perplexity":   optionalInt(),
			},
			Type: graphql.NewObject(graphql.ObjectConfig{
				Name: fmt.Sprintf("%sAdditionalFeatureProjection", classname),
				Fields: graphql.Fields{
					"vector": &graphql.Field{Type: graphql.NewList(graphql.Float)},
				},
			}),
		}
	}

	return modulecapabilities.AdditionalProperty{
		DefaultValue:           m.fakeProjector.AdditonalPropertyDefaultValue(),
		GraphQLNames:           []string{"featureProjection"},
		GraphQLFieldFunction:   fieldFor,
		GraphQLExtractFunction: m.fakeProjector.ExtractAdditionalFn,
	}
}
// getNearestNeighbors describes the "nearestNeighbors" additional property: an
// object with a "neighbors" list of {concept, distance} entries.
func (m *nearCustomTextModule) getNearestNeighbors() modulecapabilities.AdditionalProperty {
	fieldFor := func(classname string) *graphql.Field {
		neighbor := graphql.NewObject(graphql.ObjectConfig{
			Name: fmt.Sprintf("%sAdditionalNearestNeighborsNeighbors", classname),
			Fields: graphql.Fields{
				"concept":  &graphql.Field{Type: graphql.String},
				"distance": &graphql.Field{Type: graphql.Float},
			},
		})
		return &graphql.Field{
			Type: graphql.NewObject(graphql.ObjectConfig{
				Name: fmt.Sprintf("%sAdditionalNearestNeighbors", classname),
				Fields: graphql.Fields{
					"neighbors": &graphql.Field{Type: graphql.NewList(neighbor)},
				},
			}),
		}
	}

	return modulecapabilities.AdditionalProperty{
		DefaultValue:           m.fakeExtender.AdditonalPropertyDefaultValue(),
		GraphQLNames:           []string{"nearestNeighbors"},
		GraphQLFieldFunction:   fieldFor,
		GraphQLExtractFunction: m.fakeExtender.ExtractAdditionalFn,
	}
}
// getSemanticPath describes the "semanticPath" additional property: an object
// with a "path" list whose elements carry a concept and four distance values.
func (m *nearCustomTextModule) getSemanticPath() modulecapabilities.AdditionalProperty {
	fieldFor := func(classname string) *graphql.Field {
		element := graphql.NewObject(graphql.ObjectConfig{
			Name: fmt.Sprintf("%sAdditionalSemanticPathElement", classname),
			Fields: graphql.Fields{
				"concept":            &graphql.Field{Type: graphql.String},
				"distanceToQuery":    &graphql.Field{Type: graphql.Float},
				"distanceToResult":   &graphql.Field{Type: graphql.Float},
				"distanceToNext":     &graphql.Field{Type: graphql.Float},
				"distanceToPrevious": &graphql.Field{Type: graphql.Float},
			},
		})
		return &graphql.Field{
			Type: graphql.NewObject(graphql.ObjectConfig{
				Name: fmt.Sprintf("%sAdditionalSemanticPath", classname),
				Fields: graphql.Fields{
					"path": &graphql.Field{Type: graphql.NewList(element)},
				},
			}),
		}
	}

	return modulecapabilities.AdditionalProperty{
		DefaultValue:           m.fakePathBuilder.AdditonalPropertyDefaultValue(),
		GraphQLNames:           []string{"semanticPath"},
		GraphQLFieldFunction:   fieldFor,
		GraphQLExtractFunction: m.fakePathBuilder.ExtractAdditionalFn,
	}
}
// getInterpretation describes the "interpretation" additional property: an
// object with a "source" list of {concept, weight, occurrence} entries.
func (m *nearCustomTextModule) getInterpretation() modulecapabilities.AdditionalProperty {
	fieldFor := func(classname string) *graphql.Field {
		sourceEntry := graphql.NewObject(graphql.ObjectConfig{
			Name: fmt.Sprintf("%sAdditionalInterpretationSource", classname),
			Fields: graphql.Fields{
				"concept":    &graphql.Field{Type: graphql.String},
				"weight":     &graphql.Field{Type: graphql.Float},
				"occurrence": &graphql.Field{Type: graphql.Int},
			},
		})
		return &graphql.Field{
			Type: graphql.NewObject(graphql.ObjectConfig{
				Name: fmt.Sprintf("%sAdditionalInterpretation", classname),
				Fields: graphql.Fields{
					"source": &graphql.Field{Type: graphql.NewList(sourceEntry)},
				},
			}),
		}
	}

	return modulecapabilities.AdditionalProperty{
		DefaultValue:           m.fakeInterpretation.AdditonalPropertyDefaultValue(),
		GraphQLNames:           []string{"interpretation"},
		GraphQLFieldFunction:   fieldFor,
		GraphQLExtractFunction: m.fakeInterpretation.ExtractAdditionalFn,
	}
}
// fakeModulesProvider is a ModulesProvider backed by a single
// nearCustomTextModule.
type fakeModulesProvider struct {
	nearCustomTextModule *nearCustomTextModule
}

// newFakeModulesProvider wires a fresh nearCustomTextModule into the provider.
func newFakeModulesProvider() *fakeModulesProvider {
	module := newNearCustomTextModule()
	return &fakeModulesProvider{nearCustomTextModule: module}
}
// GetAll is not used by these tests and panics if it is called.
func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module {
	panic("implement me")
}

// VectorFromInput is not used by these tests and panics if it is called.
func (fmp *fakeModulesProvider) VectorFromInput(_ context.Context, _ string, _ string) ([]float32, error) {
	panic("not implemented")
}
// GetArguments returns the module's GraphQL arguments for class. It returns
// them only when the class's Vectorizer matches the module name; otherwise the
// result is an empty, non-nil map.
func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
	result := map[string]*graphql.ArgumentConfig{}
	module := fmp.nearCustomTextModule
	if class.Vectorizer != module.Name() {
		return result
	}
	for argName, def := range module.Arguments() {
		result[argName] = def.GetArgumentsFunction(class.Class)
	}
	return result
}
// ExtractSearchParams parses the nearCustomText argument, if present, into
// module params. Other arguments are ignored. The raw nearCustomText value is
// expected to be a map[string]interface{} and panics otherwise.
func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} {
	extracted := make(map[string]interface{})
	raw, found := arguments["nearCustomText"]
	if !found {
		return extracted
	}
	extracted["nearCustomText"] = extractNearTextParam(raw.(map[string]interface{}))
	return extracted
}
// GetAdditionalFields builds the GraphQL field for every additional property
// that defines a field function, named after class.Class.
func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field {
	props := fmp.nearCustomTextModule.AdditionalProperties()
	fields := make(map[string]*graphql.Field, len(props))
	for propName, prop := range props {
		if prop.GraphQLFieldFunction == nil {
			continue
		}
		fields[propName] = prop.GraphQLFieldFunction(class.Class)
	}
	return fields
}
// ExtractAdditionalField runs the extract hook of the additional property
// called name. It returns nil when the property is unknown or has no hook.
func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} {
	prop, known := fmp.nearCustomTextModule.AdditionalProperties()[name]
	if !known || prop.GraphQLExtractFunction == nil {
		return nil
	}
	return prop.GraphQLExtractFunction(params)
}
// GraphQLAdditionalFieldNames collects the GraphQL names of all additional
// properties. The order follows map iteration and is therefore unspecified.
// The result is never nil.
func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string {
	names := []string{}
	for _, prop := range fmp.nearCustomTextModule.AdditionalProperties() {
		// Appending a nil GraphQLNames slice is a no-op.
		names = append(names, prop.GraphQLNames...)
	}
	return names
}
// extractNearTextParam parses a raw nearCustomText argument map using the
// extract function of a fresh nearCustomTextModule.
func extractNearTextParam(param map[string]interface{}) interface{} {
	arguments := newNearCustomTextModule().Arguments()
	return arguments["nearCustomText"].ExtractFunction(param)
}
// createArg builds a GraphQL AST argument named name with a string literal
// value. The Kind fields are set to the placeholder "Kind".
func createArg(name string, value string) *ast.Argument {
	argument := &ast.Argument{
		Kind:  "Kind",
		Name:  ast.NewName(&ast.Name{Value: name}),
		Value: ast.NewStringValue(&ast.StringValue{Kind: "Kind", Value: value}),
	}
	return ast.NewArgument(argument)
}
// extractAdditionalParam runs the extract hook for the "semanticPath" or
// "featureProjection" additional property on args. Any other name yields nil.
func extractAdditionalParam(name string, args []*ast.Argument) interface{} {
	if name != "semanticPath" && name != "featureProjection" {
		return nil
	}
	prop, known := newNearCustomTextModule().AdditionalProperties()[name]
	if !known {
		return nil
	}
	return prop.GraphQLExtractFunction(args)
}
// getFakeModulesProvider returns a new fake provider typed as the package's
// ModulesProvider interface, ready to be passed to Build.
func getFakeModulesProvider() ModulesProvider {
	var provider ModulesProvider = newFakeModulesProvider()
	return provider
}
// newMockResolver returns a mock resolver for a simple schema whose classes use
// the text2vec-contextionary vectorizer, wired to the fake modules provider.
func newMockResolver() *mockResolver {
	return newMockResolverWithVectorizer(config.VectorizerModuleText2VecContextionary)
}

// newMockResolverWithVectorizer builds the Get field for a simple test schema
// whose classes use vectorizer, backed by the fake modules provider. It panics
// when the schema cannot be built, because that is a broken test setup.
func newMockResolverWithVectorizer(vectorizer string) *mockResolver {
	// The null logger's hook is not inspected by these tests.
	logger, _ := test.NewNullLogger()
	simpleSchema := test_helper.CreateSimpleSchema(vectorizer)
	field, err := Build(&simpleSchema, logger, getFakeModulesProvider())
	if err != nil {
		panic(fmt.Sprintf("could not build graphql test schema: %s", err))
	}
	return newMockResolverFromRootField(field)
}

// newMockResolverWithNoModules builds the Get field for test_helper.SimpleSchema
// with no modules provider (nil). It panics when the schema cannot be built.
func newMockResolverWithNoModules() *mockResolver {
	// The null logger's hook is not inspected by these tests.
	logger, _ := test.NewNullLogger()
	field, err := Build(&test_helper.SimpleSchema, logger, nil)
	if err != nil {
		panic(fmt.Sprintf("could not build graphql test schema: %s", err))
	}
	return newMockResolverFromRootField(field)
}

// newMockResolverFromRootField wraps a built Get field in a mockResolver. The
// root object exposes the resolver itself under "Resolver" and a no-op requests
// log under "RequestsLog".
func newMockResolverFromRootField(field *graphql.Field) *mockResolver {
	mocker := &mockResolver{}
	mocker.RootFieldName = "Get"
	mocker.RootField = field
	mocker.RootObject = map[string]interface{}{
		"Resolver":    Resolver(mocker),
		"RequestsLog": RequestsLog(&mockRequestsLog{}),
	}
	return mocker
}
// GetClass records the call through the embedded mock's Called (presumably
// testify's mock.Mock via test_helper.MockResolver; confirm) and returns the
// values configured for params.
//
// A nil first return value is allowed, e.g. when a test only mocks an error.
// It yields a nil slice instead of panicking on the type assertion. Any other
// value that is not a []interface{} still panics, which surfaces mis-configured
// mocks.
func (m *mockResolver) GetClass(ctx context.Context, principal *models.Principal,
	params dto.GetParams,
) ([]interface{}, error) {
	args := m.Called(params)
	var results []interface{}
	if raw := args.Get(0); raw != nil {
		results = raw.([]interface{})
	}
	return results, args.Error(1)
}