Spaces:
Running
Running
File size: 4,709 Bytes
b110593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package modules
import (
"context"
"testing"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/schema"
)
// TestModulesWithSearchers verifies that the Provider routes a vector search
// for a module-registered GraphQL argument ("nearGrape") to that module's
// searcher function. It covers both the per-class path
// (VectorFromSearchParam) and the cross-class path
// (CrossClassVectorFromSearchParam).
//
// The fake searcher appends a fourth dimension to whatever fakeFindVector
// returns ([1, 2, 3]). An expected result of [1, 2, 3, 4] therefore proves
// that the module's searcher was actually invoked.
func TestModulesWithSearchers(t *testing.T) {
	sch := schema.Schema{
		Objects: &models.Schema{
			Classes: []*models.Class{
				{
					Class:      "MyClass",
					Vectorizer: "mod",
					ModuleConfig: map[string]interface{}{
						"mod": map[string]interface{}{
							"some-config": "some-config-value",
						},
					},
				},
			},
		},
	}
	logger, _ := test.NewNullLogger()

	t.Run("get a vector for a class", func(t *testing.T) {
		p := NewProvider()
		p.SetSchemaGetter(&fakeSchemaGetter{
			schema: sch,
		})
		p.Register(newSearcherModule("mod").
			withArg("nearGrape").
			withSearcher("nearGrape", func(ctx context.Context, params interface{},
				className string,
				findVectorFn modulecapabilities.FindVectorFn,
				cfg moduletools.ClassConfig,
			) ([]float32, error) {
				// A per-class search must always receive a class config.
				assert.NotNil(t, cfg)
				// Take the vector from findVectorFn and append one dimension. This
				// has no real meaning, but it proves the module's searcher was the
				// one that produced the result.
				initial, err := findVectorFn(ctx, "class", "123", "")
				if err != nil {
					return nil, err
				}
				return append(initial, 4), nil
			}),
		)
		// A failed Init would otherwise go unnoticed and surface as a
		// confusing downstream failure.
		require.Nil(t, p.Init(context.Background(), nil, logger))

		res, err := p.VectorFromSearchParam(context.Background(), "MyClass",
			"nearGrape", nil, fakeFindVector, "")
		require.Nil(t, err)
		assert.Equal(t, []float32{1, 2, 3, 4}, res)
	})

	t.Run("get a vector across classes", func(t *testing.T) {
		p := NewProvider()
		p.SetSchemaGetter(&fakeSchemaGetter{
			schema: sch,
		})
		p.Register(newSearcherModule("mod").
			withArg("nearGrape").
			withSearcher("nearGrape", func(ctx context.Context, params interface{},
				className string,
				findVectorFn modulecapabilities.FindVectorFn,
				cfg moduletools.ClassConfig,
			) ([]float32, error) {
				// This is a cross-class search, such as the one used for Explore{}.
				// There is no class-based config here, but tenant information must
				// still be available. The searcher therefore receives an empty config
				// with an empty tenant, so cross-class searches work without tenant
				// context. Modules must be able to deal with this situation!
				//
				// Only dereference cfg when it is non-nil, so a regression fails
				// the assertion instead of panicking on a nil interface.
				if assert.NotNil(t, cfg) {
					assert.Equal(t, "", cfg.Tenant())
				}
				// Take the vector from findVectorFn and append one dimension. This
				// has no real meaning, but it proves the module's searcher was the
				// one that produced the result.
				initial, err := findVectorFn(ctx, "class", "123", "")
				if err != nil {
					return nil, err
				}
				return append(initial, 4), nil
			}),
		)
		require.Nil(t, p.Init(context.Background(), nil, logger))

		res, err := p.CrossClassVectorFromSearchParam(context.Background(),
			"nearGrape", nil, fakeFindVector)
		require.Nil(t, err)
		assert.Equal(t, []float32{1, 2, 3, 4}, res)
	})
}
// fakeFindVector stands in for a real vector lookup in these tests. It ignores
// every argument and always yields the fixed vector [1, 2, 3] with no error.
func fakeFindVector(_ context.Context, _ string, _ strfmt.UUID, _ string) ([]float32, error) {
	fixed := []float32{1, 2, 3}
	return fixed, nil
}
// newSearcherModule builds a dummySearcherModule named name. It wraps a fresh
// dummy GraphQL module and starts with an empty, non-nil searcher registry.
func newSearcherModule(name string) *dummySearcherModule {
	base := newGraphQLModule(name)
	mod := &dummySearcherModule{dummyGraphQLModule: base}
	mod.searchers = make(map[string]modulecapabilities.VectorForParams)
	return mod
}
// dummySearcherModule is a test module that adds vector searchers on top of the
// dummy GraphQL module. The dummy GraphQL module's methods are promoted through
// embedding.
type dummySearcherModule struct {
	*dummyGraphQLModule

	// searchers holds, keyed by GraphQL argument name, the search function
	// registered for that argument via withSearcher.
	searchers map[string]modulecapabilities.VectorForParams
}
// withArg registers arg on the embedded GraphQL module. It returns the searcher
// module itself, not the embedded module's builder result, so chained calls
// such as withSearcher remain available.
func (m *dummySearcherModule) withArg(arg string) *dummySearcherModule {
	inner := m.dummyGraphQLModule
	inner.withArg(arg)
	return m
}
// withSearcher is a test builder that associates impl with the GraphQL argument
// arg. A previously registered function for the same argument is replaced.
// It returns the module so calls can be chained.
func (m *dummySearcherModule) withSearcher(arg string,
	impl modulecapabilities.VectorForParams,
) *dummySearcherModule {
	registry := m.searchers
	registry[arg] = impl
	return m
}
// VectorSearches exposes the registered searchers, keyed by argument name. It
// is the public method through which this test module is meant to satisfy the
// modulecapabilities.Searcher interface. The map is returned directly, not
// copied.
func (m *dummySearcherModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
	registered := m.searchers
	return registered
}
|