Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package modules | |
import ( | |
"context" | |
"testing" | |
"github.com/go-openapi/strfmt" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/entities/schema" | |
) | |
func TestModulesWithSearchers(t *testing.T) { | |
sch := schema.Schema{ | |
Objects: &models.Schema{ | |
Classes: []*models.Class{ | |
{ | |
Class: "MyClass", | |
Vectorizer: "mod", | |
ModuleConfig: map[string]interface{}{ | |
"mod": map[string]interface{}{ | |
"some-config": "some-config-value", | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
logger, _ := test.NewNullLogger() | |
t.Run("get a vector for a class", func(t *testing.T) { | |
p := NewProvider() | |
p.SetSchemaGetter(&fakeSchemaGetter{ | |
schema: sch, | |
}) | |
p.Register(newSearcherModule("mod"). | |
withArg("nearGrape"). | |
withSearcher("nearGrape", func(ctx context.Context, params interface{}, | |
className string, | |
findVectorFn modulecapabilities.FindVectorFn, | |
cfg moduletools.ClassConfig, | |
) ([]float32, error) { | |
// verify that the config tool is set, as this is a per-class search, | |
// so it must be set | |
assert.NotNil(t, cfg) | |
// take the findVectorFn and append one dimension. This doesn't make too | |
// much sense, but helps verify that the modules method was used in the | |
// decisions | |
initial, _ := findVectorFn(ctx, "class", "123", "") | |
return append(initial, 4), nil | |
}), | |
) | |
p.Init(context.Background(), nil, logger) | |
res, err := p.VectorFromSearchParam(context.Background(), "MyClass", | |
"nearGrape", nil, fakeFindVector, "") | |
require.Nil(t, err) | |
assert.Equal(t, []float32{1, 2, 3, 4}, res) | |
}) | |
t.Run("get a vector across classes", func(t *testing.T) { | |
p := NewProvider() | |
p.SetSchemaGetter(&fakeSchemaGetter{ | |
schema: sch, | |
}) | |
p.Register(newSearcherModule("mod"). | |
withArg("nearGrape"). | |
withSearcher("nearGrape", func(ctx context.Context, params interface{}, | |
className string, | |
findVectorFn modulecapabilities.FindVectorFn, | |
cfg moduletools.ClassConfig, | |
) ([]float32, error) { | |
// this is a cross-class search, such as is used for Explore{}, in this | |
// case we do not have class-based config, but we need at least pass | |
// a tenant information, that's why we pass an empty config with empty tenant | |
// so that it would be possible to perform cross class searches, without | |
// tenant context. Modules must be able to deal with this situation! | |
assert.NotNil(t, cfg) | |
assert.Equal(t, "", cfg.Tenant()) | |
// take the findVectorFn and append one dimension. This doesn't make too | |
// much sense, but helps verify that the modules method was used in the | |
// decisions | |
initial, _ := findVectorFn(ctx, "class", "123", "") | |
return append(initial, 4), nil | |
}), | |
) | |
p.Init(context.Background(), nil, logger) | |
res, err := p.CrossClassVectorFromSearchParam(context.Background(), | |
"nearGrape", nil, fakeFindVector) | |
require.Nil(t, err) | |
assert.Equal(t, []float32{1, 2, 3, 4}, res) | |
}) | |
} | |
func fakeFindVector(ctx context.Context, className string, id strfmt.UUID, tenant string) ([]float32, error) { | |
return []float32{1, 2, 3}, nil | |
} | |
func newSearcherModule(name string) *dummySearcherModule { | |
return &dummySearcherModule{ | |
dummyGraphQLModule: newGraphQLModule(name), | |
searchers: map[string]modulecapabilities.VectorForParams{}, | |
} | |
} | |
type dummySearcherModule struct { | |
*dummyGraphQLModule | |
searchers map[string]modulecapabilities.VectorForParams | |
} | |
func (m *dummySearcherModule) withArg(arg string) *dummySearcherModule { | |
// call the super's withArg | |
m.dummyGraphQLModule.withArg(arg) | |
// but don't return their return type but ours :) | |
return m | |
} | |
// a helper for our test | |
func (m *dummySearcherModule) withSearcher(arg string, | |
impl modulecapabilities.VectorForParams, | |
) *dummySearcherModule { | |
m.searchers[arg] = impl | |
return m | |
} | |
// public method to implement the modulecapabilities.Searcher interface | |
func (m *dummySearcherModule) VectorSearches() map[string]modulecapabilities.VectorForParams { | |
return m.searchers | |
} | |