File size: 4,709 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package modules

import (
	"context"
	"testing"

	"github.com/go-openapi/strfmt"
	"github.com/sirupsen/logrus/hooks/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/moduletools"
	"github.com/weaviate/weaviate/entities/schema"
)

// TestModulesWithSearchers checks that a Provider dispatches the "nearGrape"
// search argument to the searcher registered by a module, once for a search
// scoped to a single class and once for a search across classes.
//
// In both cases the registered searcher calls the supplied findVectorFn and
// appends one extra dimension, so the expected result {1, 2, 3, 4} can only be
// produced if the module's searcher was actually invoked.
func TestModulesWithSearchers(t *testing.T) {
	// Schema with a single class that is vectorized by (and configured for)
	// the module named "mod".
	sch := schema.Schema{
		Objects: &models.Schema{
			Classes: []*models.Class{
				{
					Class:      "MyClass",
					Vectorizer: "mod",
					ModuleConfig: map[string]interface{}{
						"mod": map[string]interface{}{
							"some-config": "some-config-value",
						},
					},
				},
			},
		},
	}
	logger, _ := test.NewNullLogger()

	t.Run("get a vector for a class", func(t *testing.T) {
		p := NewProvider()
		p.SetSchemaGetter(&fakeSchemaGetter{
			schema: sch,
		})
		p.Register(newSearcherModule("mod").
			withArg("nearGrape").
			withSearcher("nearGrape", func(ctx context.Context, params interface{},
				className string,
				findVectorFn modulecapabilities.FindVectorFn,
				cfg moduletools.ClassConfig,
			) ([]float32, error) {
				// verify that the config tool is set, as this is a per-class search,
				// so it must be set
				assert.NotNil(t, cfg)

				// take the findVectorFn and append one dimension. This doesn't make too
				// much sense, but helps verify that the modules method was used in the
				// decisions
				initial, err := findVectorFn(ctx, "class", "123", "")
				if err != nil {
					// surface lookup failures instead of silently returning a
					// shorter vector
					return nil, err
				}
				return append(initial, 4), nil
			}),
		)
		p.Init(context.Background(), nil, logger)

		res, err := p.VectorFromSearchParam(context.Background(), "MyClass",
			"nearGrape", nil, fakeFindVector, "")

		require.NoError(t, err)
		assert.Equal(t, []float32{1, 2, 3, 4}, res)
	})

	t.Run("get a vector across classes", func(t *testing.T) {
		p := NewProvider()
		p.SetSchemaGetter(&fakeSchemaGetter{
			schema: sch,
		})
		p.Register(newSearcherModule("mod").
			withArg("nearGrape").
			withSearcher("nearGrape", func(ctx context.Context, params interface{},
				className string,
				findVectorFn modulecapabilities.FindVectorFn,
				cfg moduletools.ClassConfig,
			) ([]float32, error) {
				// this is a cross-class search, such as is used for Explore{}, in this
				// case we do not have class-based config, but we need at least pass
				// a tenant information, that's why we pass an empty config with empty tenant
				// so that it would be possible to perform cross class searches, without
				// tenant context. Modules must be able to deal with this situation!
				assert.NotNil(t, cfg)
				assert.Equal(t, "", cfg.Tenant())

				// take the findVectorFn and append one dimension. This doesn't make too
				// much sense, but helps verify that the modules method was used in the
				// decisions
				initial, err := findVectorFn(ctx, "class", "123", "")
				if err != nil {
					// surface lookup failures instead of silently returning a
					// shorter vector
					return nil, err
				}
				return append(initial, 4), nil
			}),
		)
		p.Init(context.Background(), nil, logger)

		res, err := p.CrossClassVectorFromSearchParam(context.Background(),
			"nearGrape", nil, fakeFindVector)

		require.NoError(t, err)
		assert.Equal(t, []float32{1, 2, 3, 4}, res)
	})
}

// fakeFindVector is a stub vector lookup used by the tests in this file. It
// ignores every argument and always yields the fixed vector {1, 2, 3} together
// with a nil error.
func fakeFindVector(ctx context.Context, className string, id strfmt.UUID, tenant string) ([]float32, error) {
	fixed := []float32{1, 2, 3}
	return fixed, nil
}

// newSearcherModule returns a dummySearcherModule whose embedded GraphQL module
// is built with newGraphQLModule(name) and whose searcher map starts out empty.
func newSearcherModule(name string) *dummySearcherModule {
	mod := new(dummySearcherModule)
	mod.dummyGraphQLModule = newGraphQLModule(name)
	mod.searchers = make(map[string]modulecapabilities.VectorForParams)
	return mod
}

// dummySearcherModule is a test module that embeds *dummyGraphQLModule and
// additionally carries a set of vector searchers.
type dummySearcherModule struct {
	*dummyGraphQLModule
	// searchers maps an argument name to the implementation registered for it
	// via withSearcher; VectorSearches returns this map as-is.
	searchers map[string]modulecapabilities.VectorForParams
}

// withArg forwards arg to the embedded dummyGraphQLModule's withArg and then
// returns the receiver, so that builder-style call chains keep the
// *dummySearcherModule type.
func (m *dummySearcherModule) withArg(arg string) *dummySearcherModule {
	// call the embedded module's withArg; the explicit selector is required,
	// a plain m.withArg would resolve to this method
	m.dummyGraphQLModule.withArg(arg)

	// return our own type rather than whatever the embedded call returned
	return m
}

// withSearcher is a test helper that stores impl in the searchers map under
// the key arg, replacing any previous entry, and returns the receiver so calls
// can be chained.
func (m *dummySearcherModule) withSearcher(arg string, impl modulecapabilities.VectorForParams) *dummySearcherModule {
	m.searchers[arg] = impl
	return m
}

// VectorSearches is the public method that implements the
// modulecapabilities.Searcher interface. It returns the module's searchers map
// directly (not a copy), keyed by the argument name given to withSearcher.
func (m *dummySearcherModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
	return m.searchers
}