MVPilgrim committed on
Commit
6b502ec
·
1 Parent(s): 228dfde
modules/text2vec-transformers/clients/meta.go ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "context"
16
+ "encoding/json"
17
+ "fmt"
18
+ "io"
19
+ "net/http"
20
+ "strings"
21
+ "sync"
22
+
23
+ "github.com/pkg/errors"
24
+ )
25
+
26
+ func (v *vectorizer) MetaInfo() (map[string]interface{}, error) {
27
+ type nameMetaErr struct {
28
+ name string
29
+ meta map[string]interface{}
30
+ err error
31
+ }
32
+
33
+ endpoints := map[string]string{}
34
+ if v.originPassage != v.originQuery {
35
+ endpoints["passage"] = v.urlPassage("/meta")
36
+ endpoints["query"] = v.urlQuery("/meta")
37
+ } else {
38
+ endpoints[""] = v.urlPassage("/meta")
39
+ }
40
+
41
+ var wg sync.WaitGroup
42
+ ch := make(chan nameMetaErr, len(endpoints))
43
+ for serviceName, endpoint := range endpoints {
44
+ wg.Add(1)
45
+ go func(serviceName string, endpoint string) {
46
+ defer wg.Done()
47
+ meta, err := v.metaInfo(endpoint)
48
+ ch <- nameMetaErr{serviceName, meta, err}
49
+ }(serviceName, endpoint)
50
+ }
51
+ wg.Wait()
52
+ close(ch)
53
+
54
+ metas := map[string]interface{}{}
55
+ var errs []string
56
+ for nme := range ch {
57
+ if nme.err != nil {
58
+ prefix := ""
59
+ if nme.name != "" {
60
+ prefix = "[" + nme.name + "] "
61
+ }
62
+ errs = append(errs, fmt.Sprintf("%s%v", prefix, nme.err.Error()))
63
+ }
64
+ if nme.meta != nil {
65
+ metas[nme.name] = nme.meta
66
+ }
67
+ }
68
+
69
+ if len(errs) > 0 {
70
+ return nil, errors.Errorf(strings.Join(errs, ", "))
71
+ }
72
+ if len(metas) == 1 {
73
+ for _, meta := range metas {
74
+ return meta.(map[string]interface{}), nil
75
+ }
76
+ }
77
+ return metas, nil
78
+ }
79
+
80
+ func (v *vectorizer) metaInfo(endpoint string) (map[string]interface{}, error) {
81
+ req, err := http.NewRequestWithContext(context.Background(), "GET", endpoint, nil)
82
+ if err != nil {
83
+ return nil, errors.Wrap(err, "create GET meta request")
84
+ }
85
+
86
+ res, err := v.httpClient.Do(req)
87
+ if err != nil {
88
+ return nil, errors.Wrap(err, "send GET meta request")
89
+ }
90
+ defer res.Body.Close()
91
+ if !(res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices) {
92
+ return nil, errors.Errorf("unexpected status code '%d' of meta request", res.StatusCode)
93
+ }
94
+
95
+ bodyBytes, err := io.ReadAll(res.Body)
96
+ if err != nil {
97
+ return nil, errors.Wrap(err, "read meta response body")
98
+ }
99
+
100
+ var resBody map[string]interface{}
101
+ if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
102
+ return nil, errors.Wrap(err, "unmarshal meta response body")
103
+ }
104
+ return resBody, nil
105
+ }
modules/text2vec-transformers/clients/meta_test.go ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "net/http"
16
+ "net/http/httptest"
17
+ "testing"
18
+ "time"
19
+
20
+ "github.com/stretchr/testify/assert"
21
+ )
22
+
23
+ func TestGetMeta(t *testing.T) {
24
+ t.Run("when common server is providing meta", func(t *testing.T) {
25
+ server := httptest.NewServer(&testMetaHandler{t: t})
26
+ defer server.Close()
27
+ v := New(server.URL, server.URL, 0, nullLogger())
28
+ meta, err := v.MetaInfo()
29
+
30
+ assert.Nil(t, err)
31
+ assert.NotNil(t, meta)
32
+
33
+ model := extractChildMap(t, meta, "model")
34
+ assert.NotNil(t, model["_name_or_path"])
35
+ assert.NotNil(t, model["architectures"])
36
+ assert.Contains(t, model["architectures"], "DistilBertModel")
37
+ ID2Label := extractChildMap(t, model, "id2label")
38
+ assert.NotNil(t, ID2Label["0"])
39
+ assert.NotNil(t, ID2Label["1"])
40
+ })
41
+
42
+ t.Run("when passage and query servers are providing meta", func(t *testing.T) {
43
+ serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"})
44
+ serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"})
45
+ defer serverPassage.Close()
46
+ defer serverQuery.Close()
47
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
48
+ meta, err := v.MetaInfo()
49
+
50
+ assert.Nil(t, err)
51
+ assert.NotNil(t, meta)
52
+
53
+ passage := extractChildMap(t, meta, "passage")
54
+ passageModel := extractChildMap(t, passage, "model")
55
+ assert.NotNil(t, passageModel["_name_or_path"])
56
+ assert.NotNil(t, passageModel["architectures"])
57
+ assert.Contains(t, passageModel["architectures"], "DPRContextEncoder")
58
+ passageID2Label := extractChildMap(t, passageModel, "id2label")
59
+ assert.NotNil(t, passageID2Label["0"])
60
+ assert.NotNil(t, passageID2Label["1"])
61
+
62
+ query := extractChildMap(t, meta, "query")
63
+ queryModel := extractChildMap(t, query, "model")
64
+ assert.NotNil(t, queryModel["_name_or_path"])
65
+ assert.NotNil(t, queryModel["architectures"])
66
+ assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder")
67
+ queryID2Label := extractChildMap(t, queryModel, "id2label")
68
+ assert.NotNil(t, queryID2Label["0"])
69
+ assert.NotNil(t, queryID2Label["1"])
70
+ })
71
+
72
+ t.Run("when passage and query servers are unavailable", func(t *testing.T) {
73
+ rt := time.Now().Add(time.Hour)
74
+ serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt})
75
+ serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt})
76
+ defer serverPassage.Close()
77
+ defer serverQuery.Close()
78
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
79
+ meta, err := v.MetaInfo()
80
+
81
+ assert.NotNil(t, err)
82
+ assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request")
83
+ assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request")
84
+ assert.Nil(t, meta)
85
+ })
86
+ }
87
+
88
+ type testMetaHandler struct {
89
+ t *testing.T
90
+ // the test handler will report as not ready before the time has passed
91
+ readyTime time.Time
92
+ modelType string
93
+ }
94
+
95
+ func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
96
+ assert.Equal(h.t, "/meta", r.URL.String())
97
+ assert.Equal(h.t, http.MethodGet, r.Method)
98
+
99
+ if time.Since(h.readyTime) < 0 {
100
+ w.WriteHeader(http.StatusServiceUnavailable)
101
+ return
102
+ }
103
+
104
+ w.Write([]byte(h.metaInfo()))
105
+ }
106
+
107
+ func (h *testMetaHandler) metaInfo() string {
108
+ switch h.modelType {
109
+ case "passage":
110
+ return `{
111
+ "model": {
112
+ "return_dict": true,
113
+ "output_hidden_states": false,
114
+ "output_attentions": false,
115
+ "torchscript": false,
116
+ "torch_dtype": "float32",
117
+ "use_bfloat16": false,
118
+ "pruned_heads": {},
119
+ "tie_word_embeddings": true,
120
+ "is_encoder_decoder": false,
121
+ "is_decoder": false,
122
+ "cross_attention_hidden_size": null,
123
+ "add_cross_attention": false,
124
+ "tie_encoder_decoder": false,
125
+ "max_length": 20,
126
+ "min_length": 0,
127
+ "do_sample": false,
128
+ "early_stopping": false,
129
+ "num_beams": 1,
130
+ "num_beam_groups": 1,
131
+ "diversity_penalty": 0,
132
+ "temperature": 1,
133
+ "top_k": 50,
134
+ "top_p": 1,
135
+ "repetition_penalty": 1,
136
+ "length_penalty": 1,
137
+ "no_repeat_ngram_size": 0,
138
+ "encoder_no_repeat_ngram_size": 0,
139
+ "bad_words_ids": null,
140
+ "num_return_sequences": 1,
141
+ "chunk_size_feed_forward": 0,
142
+ "output_scores": false,
143
+ "return_dict_in_generate": false,
144
+ "forced_bos_token_id": null,
145
+ "forced_eos_token_id": null,
146
+ "remove_invalid_values": false,
147
+ "architectures": [
148
+ "DPRContextEncoder"
149
+ ],
150
+ "finetuning_task": null,
151
+ "id2label": {
152
+ "0": "LABEL_0",
153
+ "1": "LABEL_1"
154
+ },
155
+ "label2id": {
156
+ "LABEL_0": 0,
157
+ "LABEL_1": 1
158
+ },
159
+ "tokenizer_class": null,
160
+ "prefix": null,
161
+ "bos_token_id": null,
162
+ "pad_token_id": 0,
163
+ "eos_token_id": null,
164
+ "sep_token_id": null,
165
+ "decoder_start_token_id": null,
166
+ "task_specific_params": null,
167
+ "problem_type": null,
168
+ "_name_or_path": "./models/model",
169
+ "transformers_version": "4.16.2",
170
+ "gradient_checkpointing": false,
171
+ "model_type": "dpr",
172
+ "vocab_size": 30522,
173
+ "hidden_size": 768,
174
+ "num_hidden_layers": 12,
175
+ "num_attention_heads": 12,
176
+ "hidden_act": "gelu",
177
+ "intermediate_size": 3072,
178
+ "hidden_dropout_prob": 0.1,
179
+ "attention_probs_dropout_prob": 0.1,
180
+ "max_position_embeddings": 512,
181
+ "type_vocab_size": 2,
182
+ "initializer_range": 0.02,
183
+ "layer_norm_eps": 1e-12,
184
+ "projection_dim": 0,
185
+ "position_embedding_type": "absolute"
186
+ }
187
+ }`
188
+ case "query":
189
+ return `{
190
+ "model": {
191
+ "return_dict": true,
192
+ "output_hidden_states": false,
193
+ "output_attentions": false,
194
+ "torchscript": false,
195
+ "torch_dtype": "float32",
196
+ "use_bfloat16": false,
197
+ "pruned_heads": {},
198
+ "tie_word_embeddings": true,
199
+ "is_encoder_decoder": false,
200
+ "is_decoder": false,
201
+ "cross_attention_hidden_size": null,
202
+ "add_cross_attention": false,
203
+ "tie_encoder_decoder": false,
204
+ "max_length": 20,
205
+ "min_length": 0,
206
+ "do_sample": false,
207
+ "early_stopping": false,
208
+ "num_beams": 1,
209
+ "num_beam_groups": 1,
210
+ "diversity_penalty": 0,
211
+ "temperature": 1,
212
+ "top_k": 50,
213
+ "top_p": 1,
214
+ "repetition_penalty": 1,
215
+ "length_penalty": 1,
216
+ "no_repeat_ngram_size": 0,
217
+ "encoder_no_repeat_ngram_size": 0,
218
+ "bad_words_ids": null,
219
+ "num_return_sequences": 1,
220
+ "chunk_size_feed_forward": 0,
221
+ "output_scores": false,
222
+ "return_dict_in_generate": false,
223
+ "forced_bos_token_id": null,
224
+ "forced_eos_token_id": null,
225
+ "remove_invalid_values": false,
226
+ "architectures": [
227
+ "DPRQuestionEncoder"
228
+ ],
229
+ "finetuning_task": null,
230
+ "id2label": {
231
+ "0": "LABEL_0",
232
+ "1": "LABEL_1"
233
+ },
234
+ "label2id": {
235
+ "LABEL_0": 0,
236
+ "LABEL_1": 1
237
+ },
238
+ "tokenizer_class": null,
239
+ "prefix": null,
240
+ "bos_token_id": null,
241
+ "pad_token_id": 0,
242
+ "eos_token_id": null,
243
+ "sep_token_id": null,
244
+ "decoder_start_token_id": null,
245
+ "task_specific_params": null,
246
+ "problem_type": null,
247
+ "_name_or_path": "./models/model",
248
+ "transformers_version": "4.16.2",
249
+ "gradient_checkpointing": false,
250
+ "model_type": "dpr",
251
+ "vocab_size": 30522,
252
+ "hidden_size": 768,
253
+ "num_hidden_layers": 12,
254
+ "num_attention_heads": 12,
255
+ "hidden_act": "gelu",
256
+ "intermediate_size": 3072,
257
+ "hidden_dropout_prob": 0.1,
258
+ "attention_probs_dropout_prob": 0.1,
259
+ "max_position_embeddings": 512,
260
+ "type_vocab_size": 2,
261
+ "initializer_range": 0.02,
262
+ "layer_norm_eps": 1e-12,
263
+ "projection_dim": 0,
264
+ "position_embedding_type": "absolute"
265
+ }
266
+ }`
267
+ default:
268
+ return `{
269
+ "model": {
270
+ "_name_or_path": "distilbert-base-uncased",
271
+ "activation": "gelu",
272
+ "add_cross_attention": false,
273
+ "architectures": [
274
+ "DistilBertModel"
275
+ ],
276
+ "attention_dropout": 0.1,
277
+ "bad_words_ids": null,
278
+ "bos_token_id": null,
279
+ "chunk_size_feed_forward": 0,
280
+ "decoder_start_token_id": null,
281
+ "dim": 768,
282
+ "diversity_penalty": 0,
283
+ "do_sample": false,
284
+ "dropout": 0.1,
285
+ "early_stopping": false,
286
+ "encoder_no_repeat_ngram_size": 0,
287
+ "eos_token_id": null,
288
+ "finetuning_task": null,
289
+ "hidden_dim": 3072,
290
+ "id2label": {
291
+ "0": "LABEL_0",
292
+ "1": "LABEL_1"
293
+ },
294
+ "initializer_range": 0.02,
295
+ "is_decoder": false,
296
+ "is_encoder_decoder": false,
297
+ "label2id": {
298
+ "LABEL_0": 0,
299
+ "LABEL_1": 1
300
+ },
301
+ "length_penalty": 1,
302
+ "max_length": 20,
303
+ "max_position_embeddings": 512,
304
+ "min_length": 0,
305
+ "model_type": "distilbert",
306
+ "n_heads": 12,
307
+ "n_layers": 6,
308
+ "no_repeat_ngram_size": 0,
309
+ "num_beam_groups": 1,
310
+ "num_beams": 1,
311
+ "num_return_sequences": 1,
312
+ "output_attentions": false,
313
+ "output_hidden_states": false,
314
+ "output_scores": false,
315
+ "pad_token_id": 0,
316
+ "prefix": null,
317
+ "pruned_heads": {},
318
+ "qa_dropout": 0.1,
319
+ "repetition_penalty": 1,
320
+ "return_dict": true,
321
+ "return_dict_in_generate": false,
322
+ "sep_token_id": null,
323
+ "seq_classif_dropout": 0.2,
324
+ "sinusoidal_pos_embds": false,
325
+ "task_specific_params": null,
326
+ "temperature": 1,
327
+ "tie_encoder_decoder": false,
328
+ "tie_weights_": true,
329
+ "tie_word_embeddings": true,
330
+ "tokenizer_class": null,
331
+ "top_k": 50,
332
+ "top_p": 1,
333
+ "torchscript": false,
334
+ "transformers_version": "4.3.2",
335
+ "use_bfloat16": false,
336
+ "vocab_size": 30522,
337
+ "xla_device": null
338
+ }
339
+ }`
340
+ }
341
+ }
342
+
343
+ func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} {
344
+ assert.NotNil(t, parent[name])
345
+ child, ok := parent[name].(map[string]interface{})
346
+ assert.True(t, ok)
347
+ assert.NotNil(t, child)
348
+
349
+ return child
350
+ }
modules/text2vec-transformers/clients/startup.go ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "context"
16
+ "net/http"
17
+ "strings"
18
+ "sync"
19
+ "time"
20
+
21
+ "github.com/pkg/errors"
22
+ )
23
+
24
+ func (v *vectorizer) WaitForStartup(initCtx context.Context,
25
+ interval time.Duration,
26
+ ) error {
27
+ endpoints := map[string]string{}
28
+ if v.originPassage != v.originQuery {
29
+ endpoints["passage"] = v.urlPassage("/.well-known/ready")
30
+ endpoints["query"] = v.urlQuery("/.well-known/ready")
31
+ } else {
32
+ endpoints[""] = v.urlPassage("/.well-known/ready")
33
+ }
34
+
35
+ ch := make(chan error, len(endpoints))
36
+ var wg sync.WaitGroup
37
+ for serviceName, endpoint := range endpoints {
38
+ wg.Add(1)
39
+ go func(serviceName string, endpoint string) {
40
+ defer wg.Done()
41
+ if err := v.waitFor(initCtx, interval, endpoint, serviceName); err != nil {
42
+ ch <- err
43
+ }
44
+ }(serviceName, endpoint)
45
+ }
46
+ wg.Wait()
47
+ close(ch)
48
+
49
+ if len(ch) > 0 {
50
+ var errs []string
51
+ for err := range ch {
52
+ errs = append(errs, err.Error())
53
+ }
54
+ return errors.Errorf(strings.Join(errs, ", "))
55
+ }
56
+ return nil
57
+ }
58
+
59
+ func (v *vectorizer) waitFor(initCtx context.Context, interval time.Duration, endpoint string, serviceName string) error {
60
+ ticker := time.NewTicker(interval)
61
+ defer ticker.Stop()
62
+ expired := initCtx.Done()
63
+ var lastErr error
64
+ prefix := ""
65
+ if serviceName != "" {
66
+ prefix = "[" + serviceName + "] "
67
+ }
68
+
69
+ for {
70
+ select {
71
+ case <-ticker.C:
72
+ lastErr = v.checkReady(initCtx, endpoint, serviceName)
73
+ if lastErr == nil {
74
+ return nil
75
+ }
76
+ v.logger.
77
+ WithField("action", "transformer_remote_wait_for_startup").
78
+ WithError(lastErr).Warnf("%stransformer remote inference service not ready", prefix)
79
+ case <-expired:
80
+ return errors.Wrapf(lastErr, "%sinit context expired before remote was ready", prefix)
81
+ }
82
+ }
83
+ }
84
+
85
+ func (v *vectorizer) checkReady(initCtx context.Context, endpoint string, serviceName string) error {
86
+ // spawn a new context (derived on the overall context) which is used to
87
+ // consider an individual request timed out
88
+ // due to parent timeout being superior over request's one, request can be cancelled by parent timeout
89
+ // resulting in "send check ready request" even if service is responding with non 2xx http code
90
+ requestCtx, cancel := context.WithTimeout(initCtx, 500*time.Millisecond)
91
+ defer cancel()
92
+
93
+ req, err := http.NewRequestWithContext(requestCtx, http.MethodGet, endpoint, nil)
94
+ if err != nil {
95
+ return errors.Wrap(err, "create check ready request")
96
+ }
97
+
98
+ res, err := v.httpClient.Do(req)
99
+ if err != nil {
100
+ return errors.Wrap(err, "send check ready request")
101
+ }
102
+
103
+ defer res.Body.Close()
104
+ if res.StatusCode > 299 {
105
+ return errors.Errorf("not ready: status %d", res.StatusCode)
106
+ }
107
+
108
+ return nil
109
+ }
modules/text2vec-transformers/clients/startup_test.go ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "context"
16
+ "net/http"
17
+ "net/http/httptest"
18
+ "regexp"
19
+ "strings"
20
+ "testing"
21
+ "time"
22
+
23
+ "github.com/sirupsen/logrus"
24
+ "github.com/sirupsen/logrus/hooks/test"
25
+ "github.com/stretchr/testify/assert"
26
+ "github.com/stretchr/testify/require"
27
+ )
28
+
29
+ func TestWaitForStartup(t *testing.T) {
30
+ t.Run("when common server is immediately ready", func(t *testing.T) {
31
+ server := httptest.NewServer(&testReadyHandler{t: t})
32
+ defer server.Close()
33
+ v := New(server.URL, server.URL, 0, nullLogger())
34
+ err := v.WaitForStartup(context.Background(), 150*time.Millisecond)
35
+
36
+ assert.Nil(t, err)
37
+ })
38
+
39
+ t.Run("when passage and query servers are immediately ready", func(t *testing.T) {
40
+ serverPassage := httptest.NewServer(&testReadyHandler{t: t})
41
+ serverQuery := httptest.NewServer(&testReadyHandler{t: t})
42
+ defer serverPassage.Close()
43
+ defer serverQuery.Close()
44
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
45
+ err := v.WaitForStartup(context.Background(), 150*time.Millisecond)
46
+
47
+ assert.Nil(t, err)
48
+ })
49
+
50
+ t.Run("when common server is down", func(t *testing.T) {
51
+ url := "http://nothing-running-at-this-url"
52
+ v := New(url, url, 0, nullLogger())
53
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
54
+ defer cancel()
55
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
56
+
57
+ require.NotNil(t, err, nullLogger())
58
+ assert.Contains(t, err.Error(), "init context expired before remote was ready: send check ready request")
59
+ assertContainsEither(t, err.Error(), "dial tcp", "context deadline exceeded")
60
+ assert.NotContains(t, err.Error(), "[passage]")
61
+ assert.NotContains(t, err.Error(), "[query]")
62
+ })
63
+
64
+ t.Run("when passage and query servers are down", func(t *testing.T) {
65
+ urlPassage := "http://nothing-running-at-this-url"
66
+ urlQuery := "http://nothing-running-at-this-url-either"
67
+ v := New(urlPassage, urlQuery, 0, nullLogger())
68
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
69
+ defer cancel()
70
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
71
+
72
+ require.NotNil(t, err, nullLogger())
73
+ assert.Contains(t, err.Error(), "[passage] init context expired before remote was ready: send check ready request")
74
+ assert.Contains(t, err.Error(), "[query] init context expired before remote was ready: send check ready request")
75
+ assertContainsEither(t, err.Error(), "dial tcp", "context deadline exceeded")
76
+ })
77
+
78
+ t.Run("when common server is alive, but not ready", func(t *testing.T) {
79
+ server := httptest.NewServer(&testReadyHandler{
80
+ t: t,
81
+ readyTime: time.Now().Add(time.Hour),
82
+ })
83
+ defer server.Close()
84
+ v := New(server.URL, server.URL, 0, nullLogger())
85
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
86
+ defer cancel()
87
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
88
+
89
+ require.NotNil(t, err)
90
+ assert.Contains(t, err.Error(), "init context expired before remote was ready")
91
+ assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
92
+ assert.NotContains(t, err.Error(), "[passage]")
93
+ assert.NotContains(t, err.Error(), "[query]")
94
+ })
95
+
96
+ t.Run("when passage and query servers are alive, but not ready", func(t *testing.T) {
97
+ rt := time.Now().Add(time.Hour)
98
+ serverPassage := httptest.NewServer(&testReadyHandler{
99
+ t: t,
100
+ readyTime: rt,
101
+ })
102
+ serverQuery := httptest.NewServer(&testReadyHandler{
103
+ t: t,
104
+ readyTime: rt,
105
+ })
106
+ defer serverPassage.Close()
107
+ defer serverQuery.Close()
108
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
109
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
110
+ defer cancel()
111
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
112
+
113
+ require.NotNil(t, err)
114
+ assert.Contains(t, err.Error(), "[passage] init context expired before remote was ready")
115
+ assert.Contains(t, err.Error(), "[query] init context expired before remote was ready")
116
+ assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
117
+ })
118
+
119
+ t.Run("when passage and query servers are alive, but query one is not ready", func(t *testing.T) {
120
+ serverPassage := httptest.NewServer(&testReadyHandler{t: t})
121
+ serverQuery := httptest.NewServer(&testReadyHandler{
122
+ t: t,
123
+ readyTime: time.Now().Add(1 * time.Minute),
124
+ })
125
+ defer serverPassage.Close()
126
+ defer serverQuery.Close()
127
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
128
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
129
+ defer cancel()
130
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
131
+
132
+ require.NotNil(t, err)
133
+ assert.Contains(t, err.Error(), "[query] init context expired before remote was ready")
134
+ assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
135
+ assert.NotContains(t, err.Error(), "[passage]")
136
+ })
137
+
138
+ t.Run("when common server is initially not ready, but then becomes ready", func(t *testing.T) {
139
+ server := httptest.NewServer(&testReadyHandler{
140
+ t: t,
141
+ readyTime: time.Now().Add(100 * time.Millisecond),
142
+ })
143
+ v := New(server.URL, server.URL, 0, nullLogger())
144
+ defer server.Close()
145
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
146
+ defer cancel()
147
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
148
+
149
+ require.Nil(t, err)
150
+ })
151
+
152
+ t.Run("when passage and query servers are initially not ready, but then become ready", func(t *testing.T) {
153
+ serverPassage := httptest.NewServer(&testReadyHandler{
154
+ t: t,
155
+ readyTime: time.Now().Add(100 * time.Millisecond),
156
+ })
157
+ serverQuery := httptest.NewServer(&testReadyHandler{
158
+ t: t,
159
+ readyTime: time.Now().Add(150 * time.Millisecond),
160
+ })
161
+ defer serverPassage.Close()
162
+ defer serverQuery.Close()
163
+ v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
164
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
165
+ defer cancel()
166
+ err := v.WaitForStartup(ctx, 50*time.Millisecond)
167
+
168
+ require.Nil(t, err)
169
+ })
170
+ }
171
+
172
+ type testReadyHandler struct {
173
+ t *testing.T
174
+ // the test handler will report as not ready before the time has passed
175
+ readyTime time.Time
176
+ }
177
+
178
+ func (f *testReadyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
179
+ assert.Equal(f.t, "/.well-known/ready", r.URL.String())
180
+ assert.Equal(f.t, http.MethodGet, r.Method)
181
+
182
+ if time.Since(f.readyTime) < 0 {
183
+ w.WriteHeader(http.StatusServiceUnavailable)
184
+ } else {
185
+ w.WriteHeader(http.StatusNoContent)
186
+ }
187
+ }
188
+
189
+ func nullLogger() logrus.FieldLogger {
190
+ l, _ := test.NewNullLogger()
191
+ return l
192
+ }
193
+
194
+ func assertContainsEither(t *testing.T, str string, contains ...string) {
195
+ reg := regexp.MustCompile(strings.Join(contains, "|"))
196
+ assert.Regexp(t, reg, str)
197
+ }
modules/text2vec-transformers/clients/transformers.go ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "bytes"
16
+ "context"
17
+ "encoding/json"
18
+ "fmt"
19
+ "io"
20
+ "net/http"
21
+ "time"
22
+
23
+ "github.com/pkg/errors"
24
+ "github.com/sirupsen/logrus"
25
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
26
+ )
27
+
28
+ type vectorizer struct {
29
+ originPassage string
30
+ originQuery string
31
+ httpClient *http.Client
32
+ logger logrus.FieldLogger
33
+ }
34
+
35
+ func New(originPassage, originQuery string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
36
+ return &vectorizer{
37
+ originPassage: originPassage,
38
+ originQuery: originQuery,
39
+ httpClient: &http.Client{
40
+ Timeout: timeout,
41
+ },
42
+ logger: logger,
43
+ }
44
+ }
45
+
46
+ func (v *vectorizer) VectorizeObject(ctx context.Context, input string,
47
+ config ent.VectorizationConfig,
48
+ ) (*ent.VectorizationResult, error) {
49
+ return v.vectorize(ctx, input, config, v.urlPassage)
50
+ }
51
+
52
+ func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
53
+ config ent.VectorizationConfig,
54
+ ) (*ent.VectorizationResult, error) {
55
+ return v.vectorize(ctx, input, config, v.urlQuery)
56
+ }
57
+
58
+ func (v *vectorizer) vectorize(ctx context.Context, input string,
59
+ config ent.VectorizationConfig, url func(string) string,
60
+ ) (*ent.VectorizationResult, error) {
61
+ body, err := json.Marshal(vecRequest{
62
+ Text: input,
63
+ Config: vecRequestConfig{
64
+ PoolingStrategy: config.PoolingStrategy,
65
+ },
66
+ })
67
+ if err != nil {
68
+ return nil, errors.Wrapf(err, "marshal body")
69
+ }
70
+
71
+ req, err := http.NewRequestWithContext(ctx, "POST", url("/vectors"),
72
+ bytes.NewReader(body))
73
+ if err != nil {
74
+ return nil, errors.Wrap(err, "create POST request")
75
+ }
76
+
77
+ res, err := v.httpClient.Do(req)
78
+ if err != nil {
79
+ return nil, errors.Wrap(err, "send POST request")
80
+ }
81
+ defer res.Body.Close()
82
+
83
+ bodyBytes, err := io.ReadAll(res.Body)
84
+ if err != nil {
85
+ return nil, errors.Wrap(err, "read response body")
86
+ }
87
+
88
+ var resBody vecRequest
89
+ if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
90
+ return nil, errors.Wrap(err, "unmarshal response body")
91
+ }
92
+
93
+ if res.StatusCode > 399 {
94
+ return nil, errors.Errorf("fail with status %d: %s", res.StatusCode,
95
+ resBody.Error)
96
+ }
97
+
98
+ return &ent.VectorizationResult{
99
+ Text: resBody.Text,
100
+ Dimensions: resBody.Dims,
101
+ Vector: resBody.Vector,
102
+ }, nil
103
+ }
104
+
105
+ func (v *vectorizer) urlPassage(path string) string {
106
+ return fmt.Sprintf("%s%s", v.originPassage, path)
107
+ }
108
+
109
+ func (v *vectorizer) urlQuery(path string) string {
110
+ return fmt.Sprintf("%s%s", v.originQuery, path)
111
+ }
112
+
113
+ type vecRequest struct {
114
+ Text string `json:"text"`
115
+ Dims int `json:"dims"`
116
+ Vector []float32 `json:"vector"`
117
+ Error string `json:"error"`
118
+ Config vecRequestConfig `json:"config"`
119
+ }
120
+
121
// vecRequestConfig holds the per-request options sent under the "config" key
// of a vecRequest.
type vecRequestConfig struct {
	PoolingStrategy string `json:"pooling_strategy"`
}
modules/text2vec-transformers/clients/transformers_test.go ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: [email protected]
10
+ //
11
+
12
+ package clients
13
+
14
+ import (
15
+ "context"
16
+ "encoding/json"
17
+ "fmt"
18
+ "io"
19
+ "net/http"
20
+ "net/http/httptest"
21
+ "testing"
22
+ "time"
23
+
24
+ "github.com/pkg/errors"
25
+ "github.com/stretchr/testify/assert"
26
+ "github.com/stretchr/testify/require"
27
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
28
+ )
29
+
30
+ func TestClient(t *testing.T) {
31
+ t.Run("when all is fine", func(t *testing.T) {
32
+ server := httptest.NewServer(&fakeHandler{t: t})
33
+ defer server.Close()
34
+ c := New(server.URL, server.URL, 0, nullLogger())
35
+ expected := &ent.VectorizationResult{
36
+ Text: "This is my text",
37
+ Vector: []float32{0.1, 0.2, 0.3},
38
+ Dimensions: 3,
39
+ }
40
+ res, err := c.VectorizeObject(context.Background(), "This is my text",
41
+ ent.VectorizationConfig{
42
+ PoolingStrategy: "masked_mean",
43
+ })
44
+
45
+ assert.Nil(t, err)
46
+ assert.Equal(t, expected, res)
47
+ })
48
+
49
+ t.Run("when the context is expired", func(t *testing.T) {
50
+ server := httptest.NewServer(&fakeHandler{t: t})
51
+ defer server.Close()
52
+ c := New(server.URL, server.URL, 0, nullLogger())
53
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now())
54
+ defer cancel()
55
+
56
+ _, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{})
57
+
58
+ require.NotNil(t, err)
59
+ assert.Contains(t, err.Error(), "context deadline exceeded")
60
+ })
61
+
62
+ t.Run("when the server returns an error", func(t *testing.T) {
63
+ server := httptest.NewServer(&fakeHandler{
64
+ t: t,
65
+ serverError: errors.Errorf("nope, not gonna happen"),
66
+ })
67
+ defer server.Close()
68
+ c := New(server.URL, server.URL, 0, nullLogger())
69
+ _, err := c.VectorizeObject(context.Background(), "This is my text",
70
+ ent.VectorizationConfig{})
71
+
72
+ require.NotNil(t, err)
73
+ assert.Contains(t, err.Error(), "nope, not gonna happen")
74
+ })
75
+ }
76
+
77
+ type fakeHandler struct {
78
+ t *testing.T
79
+ serverError error
80
+ }
81
+
82
+ func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
83
+ assert.Equal(f.t, "/vectors", r.URL.String())
84
+ assert.Equal(f.t, http.MethodPost, r.Method)
85
+
86
+ if f.serverError != nil {
87
+ w.WriteHeader(http.StatusInternalServerError)
88
+ w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error())))
89
+ return
90
+ }
91
+
92
+ bodyBytes, err := io.ReadAll(r.Body)
93
+ require.Nil(f.t, err)
94
+ defer r.Body.Close()
95
+
96
+ var b map[string]interface{}
97
+ require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
98
+
99
+ textInput := b["text"].(string)
100
+ assert.Greater(f.t, len(textInput), 0)
101
+
102
+ pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string)
103
+ assert.Equal(f.t, "masked_mean", pooling)
104
+
105
+ out := map[string]interface{}{
106
+ "text": textInput,
107
+ "dims": 3,
108
+ "vector": []float32{0.1, 0.2, 0.3},
109
+ }
110
+ outBytes, err := json.Marshal(out)
111
+ require.Nil(f.t, err)
112
+
113
+ w.Write(outBytes)
114
+ }
modules/text2vec-transformers/config.go ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package modtransformers
13
+
14
+ import (
15
+ "context"
16
+ "fmt"
17
+
18
+ "github.com/pkg/errors"
19
+ "github.com/sirupsen/logrus"
20
+ "github.com/weaviate/weaviate/entities/models"
21
+ "github.com/weaviate/weaviate/entities/modulecapabilities"
22
+ "github.com/weaviate/weaviate/entities/moduletools"
23
+ "github.com/weaviate/weaviate/entities/schema"
24
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
25
+ )
26
+
27
+ func (m *TransformersModule) ClassConfigDefaults() map[string]interface{} {
28
+ return map[string]interface{}{
29
+ "vectorizeClassName": vectorizer.DefaultVectorizeClassName,
30
+ "poolingStrategy": vectorizer.DefaultPoolingStrategy,
31
+ }
32
+ }
33
+
34
+ func (m *TransformersModule) PropertyConfigDefaults(
35
+ dt *schema.DataType,
36
+ ) map[string]interface{} {
37
+ return map[string]interface{}{
38
+ "skip": !vectorizer.DefaultPropertyIndexed,
39
+ "vectorizePropertyName": vectorizer.DefaultVectorizePropertyName,
40
+ }
41
+ }
42
+
43
+ func (m *TransformersModule) ValidateClass(ctx context.Context,
44
+ class *models.Class, cfg moduletools.ClassConfig,
45
+ ) error {
46
+ settings := vectorizer.NewClassSettings(cfg)
47
+ return NewConfigValidator(m.logger).Do(ctx, class, cfg, settings)
48
+ }
49
+
50
+ var _ = modulecapabilities.ClassConfigurator(New())
51
+
52
+ type ConfigValidator struct {
53
+ logger logrus.FieldLogger
54
+ }
55
+
56
// ClassSettings exposes the vectorization switches the validator inspects.
type ClassSettings interface {
	// PropertyIndexed reports whether the named property is included in
	// vectorization.
	PropertyIndexed(propName string) bool
	// VectorizePropertyName reports whether the name of the given property
	// is vectorized.
	VectorizePropertyName(propName string) bool
	// VectorizeClassName reports whether the class name is vectorized.
	VectorizeClassName() bool
}
61
+
62
+ func NewConfigValidator(logger logrus.FieldLogger) *ConfigValidator {
63
+ return &ConfigValidator{logger: logger}
64
+ }
65
+
66
+ func (cv *ConfigValidator) Do(ctx context.Context, class *models.Class,
67
+ cfg moduletools.ClassConfig, settings ClassSettings,
68
+ ) error {
69
+ // In text2vec-transformers (as opposed to e.g. text2vec-contextionary) the
70
+ // assumption is that the models will be able to deal with any words, even
71
+ // previously unseen ones. Therefore we do not need to validate individual
72
+ // properties, but only the overall "index state"
73
+
74
+ if err := cv.validateIndexState(ctx, class, settings); err != nil {
75
+ return errors.Errorf("invalid combination of properties")
76
+ }
77
+
78
+ cv.checkForPossibilityOfDuplicateVectors(ctx, class, settings)
79
+
80
+ return nil
81
+ }
82
+
83
+ func (cv *ConfigValidator) validateIndexState(ctx context.Context,
84
+ class *models.Class, settings ClassSettings,
85
+ ) error {
86
+ if settings.VectorizeClassName() {
87
+ // if the user chooses to vectorize the classname, vector-building will
88
+ // always be possible, no need to investigate further
89
+
90
+ return nil
91
+ }
92
+
93
+ // search if there is at least one indexed, string/text prop. If found pass
94
+ // validation
95
+ for _, prop := range class.Properties {
96
+ if len(prop.DataType) < 1 {
97
+ return errors.Errorf("property %s must have at least one datatype: "+
98
+ "got %v", prop.Name, prop.DataType)
99
+ }
100
+
101
+ if prop.DataType[0] != string(schema.DataTypeText) {
102
+ // we can only vectorize text-like props
103
+ continue
104
+ }
105
+
106
+ if settings.PropertyIndexed(prop.Name) {
107
+ // found at least one, this is a valid schema
108
+ return nil
109
+ }
110
+ }
111
+
112
+ return fmt.Errorf("invalid properties: didn't find a single property which is " +
113
+ "of type string or text and is not excluded from indexing. In addition the " +
114
+ "class name is excluded from vectorization as well, meaning that it cannot be " +
115
+ "used to determine the vector position. To fix this, set 'vectorizeClassName' " +
116
+ "to true if the class name is contextionary-valid. Alternatively add at least " +
117
+ "contextionary-valid text/string property which is not excluded from " +
118
+ "indexing.")
119
+ }
120
+
121
+ func (cv *ConfigValidator) checkForPossibilityOfDuplicateVectors(
122
+ ctx context.Context, class *models.Class, settings ClassSettings,
123
+ ) {
124
+ if !settings.VectorizeClassName() {
125
+ // if the user choses not to vectorize the class name, this means they must
126
+ // have chosen something else to vectorize, otherwise the validation would
127
+ // have error'd before we ever got here. We can skip further checking.
128
+
129
+ return
130
+ }
131
+
132
+ // search if there is at least one indexed, string/text prop. If found exit
133
+ for _, prop := range class.Properties {
134
+ // length check skipped, because validation has already passed
135
+ if prop.DataType[0] != string(schema.DataTypeText) {
136
+ // we can only vectorize text-like props
137
+ continue
138
+ }
139
+
140
+ if settings.PropertyIndexed(prop.Name) {
141
+ // found at least one
142
+ return
143
+ }
144
+ }
145
+
146
+ cv.logger.WithField("module", "text2vec-transformers").
147
+ WithField("class", class.Class).
148
+ Warnf("text2vec-contextionary: Class %q does not have any properties "+
149
+ "indexed (or only non text-properties indexed) and the vector position is "+
150
+ "only determined by the class name. Each object will end up with the same "+
151
+ "vector which leads to a severe performance penalty on imports. Consider "+
152
+ "setting vectorIndexConfig.skip=true for this property", class.Class)
153
+ }
modules/text2vec-transformers/config_test.go ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package modtransformers
13
+
14
+ import (
15
+ "context"
16
+ "testing"
17
+
18
+ "github.com/sirupsen/logrus"
19
+ ltest "github.com/sirupsen/logrus/hooks/test"
20
+ "github.com/stretchr/testify/assert"
21
+ "github.com/stretchr/testify/require"
22
+ "github.com/weaviate/weaviate/entities/models"
23
+ "github.com/weaviate/weaviate/entities/schema"
24
+ )
25
+
26
+ func TestConfigDefaults(t *testing.T) {
27
+ t.Run("for properties", func(t *testing.T) {
28
+ def := New().ClassConfigDefaults()
29
+
30
+ assert.Equal(t, true, def["vectorizeClassName"])
31
+ assert.Equal(t, "masked_mean", def["poolingStrategy"])
32
+ })
33
+
34
+ t.Run("for the class", func(t *testing.T) {
35
+ dt := schema.DataTypeText
36
+ def := New().PropertyConfigDefaults(&dt)
37
+ assert.Equal(t, false, def["vectorizePropertyName"])
38
+ assert.Equal(t, false, def["skip"])
39
+ })
40
+ }
41
+
42
+ func TestConfigValidator(t *testing.T) {
43
+ t.Run("all usable props no-indexed", func(t *testing.T) {
44
+ t.Run("all schema vectorization turned off", func(t *testing.T) {
45
+ class := &models.Class{
46
+ Vectorizer: "text2vec-contextionary",
47
+ Class: "ValidName",
48
+ Properties: []*models.Property{
49
+ {
50
+ DataType: []string{"text"},
51
+ Name: "description",
52
+ },
53
+ {
54
+ DataType: schema.DataTypeText.PropString(),
55
+ Tokenization: models.PropertyTokenizationWhitespace,
56
+ Name: "name",
57
+ },
58
+ {
59
+ DataType: []string{"int"},
60
+ Name: "amount",
61
+ },
62
+ },
63
+ }
64
+
65
+ logger, _ := ltest.NewNullLogger()
66
+ v := NewConfigValidator(logger)
67
+ err := v.Do(context.Background(), class, nil, &fakeIndexChecker{
68
+ vectorizePropertyName: false,
69
+ vectorizeClassName: false,
70
+ propertyIndexed: false,
71
+ })
72
+ assert.NotNil(t, err)
73
+ })
74
+ })
75
+ }
76
+
77
+ func TestConfigValidator_RiskOfDuplicateVectors(t *testing.T) {
78
+ type test struct {
79
+ name string
80
+ in *models.Class
81
+ expectWarning bool
82
+ indexChecker *fakeIndexChecker
83
+ }
84
+
85
+ tests := []test{
86
+ {
87
+ name: "usable properties",
88
+ in: &models.Class{
89
+ Class: "ValidName",
90
+ Properties: []*models.Property{
91
+ {
92
+ DataType: []string{string(schema.DataTypeText)},
93
+ Name: "textProp",
94
+ },
95
+ },
96
+ },
97
+ expectWarning: false,
98
+ indexChecker: &fakeIndexChecker{
99
+ vectorizePropertyName: false,
100
+ vectorizeClassName: true,
101
+ propertyIndexed: true,
102
+ },
103
+ },
104
+ {
105
+ name: "no properties",
106
+ in: &models.Class{
107
+ Class: "ValidName",
108
+ },
109
+ expectWarning: true,
110
+ indexChecker: &fakeIndexChecker{
111
+ vectorizePropertyName: false,
112
+ vectorizeClassName: true,
113
+ propertyIndexed: false,
114
+ },
115
+ },
116
+ {
117
+ name: "usable properties, but they are no-indexed",
118
+ in: &models.Class{
119
+ Class: "ValidName",
120
+ Properties: []*models.Property{
121
+ {
122
+ DataType: []string{string(schema.DataTypeText)},
123
+ Name: "textProp",
124
+ },
125
+ },
126
+ },
127
+ expectWarning: true,
128
+ indexChecker: &fakeIndexChecker{
129
+ vectorizePropertyName: false,
130
+ vectorizeClassName: true,
131
+ propertyIndexed: false,
132
+ },
133
+ },
134
+ {
135
+ name: "only unusable properties",
136
+ in: &models.Class{
137
+ Class: "ValidName",
138
+ Properties: []*models.Property{
139
+ {
140
+ DataType: []string{string(schema.DataTypeInt)},
141
+ Name: "intProp",
142
+ },
143
+ },
144
+ },
145
+ expectWarning: true,
146
+ indexChecker: &fakeIndexChecker{
147
+ vectorizePropertyName: false,
148
+ vectorizeClassName: true,
149
+ propertyIndexed: false,
150
+ },
151
+ },
152
+ }
153
+
154
+ for _, test := range tests {
155
+ t.Run(test.name, func(t *testing.T) {
156
+ logger, hook := ltest.NewNullLogger()
157
+ v := NewConfigValidator(logger)
158
+ err := v.Do(context.Background(), test.in, nil, test.indexChecker)
159
+ require.Nil(t, err)
160
+
161
+ entry := hook.LastEntry()
162
+ if test.expectWarning {
163
+ require.NotNil(t, entry)
164
+ assert.Equal(t, logrus.WarnLevel, entry.Level)
165
+ } else {
166
+ assert.Nil(t, entry)
167
+ }
168
+ })
169
+ }
170
+ }
171
+
172
// fakeIndexChecker is a ClassSettings stub that answers every query with a
// fixed, preconfigured value regardless of the property asked about.
type fakeIndexChecker struct {
	// vectorizeClassName is returned by VectorizeClassName.
	vectorizeClassName bool
	// vectorizePropertyName is returned by VectorizePropertyName.
	vectorizePropertyName bool
	// propertyIndexed is returned by PropertyIndexed.
	propertyIndexed bool
}
177
+
178
+ func (f *fakeIndexChecker) VectorizeClassName() bool {
179
+ return f.vectorizeClassName
180
+ }
181
+
182
+ func (f *fakeIndexChecker) VectorizePropertyName(propName string) bool {
183
+ return f.vectorizePropertyName
184
+ }
185
+
186
+ func (f *fakeIndexChecker) PropertyIndexed(propName string) bool {
187
+ return f.propertyIndexed
188
+ }
modules/text2vec-transformers/ent/vectorization_config.go ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package ent
13
+
14
// VectorizationConfig carries the per-request options passed to the
// vectorization client.
type VectorizationConfig struct {
	// PoolingStrategy names the pooling strategy to request, for example
	// "masked_mean" or "cls".
	PoolingStrategy string
}
modules/text2vec-transformers/ent/vectorization_result.go ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package ent
13
+
14
// VectorizationResult is the outcome of vectorizing a single text.
type VectorizationResult struct {
	// Text is the input the vector belongs to.
	Text string
	// Dimensions is the reported number of vector dimensions.
	Dimensions int
	// Vector is the embedding itself.
	Vector []float32
}
modules/text2vec-transformers/module.go ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package modtransformers
13
+
14
+ import (
15
+ "context"
16
+ "net/http"
17
+ "os"
18
+ "time"
19
+
20
+ "github.com/pkg/errors"
21
+ "github.com/sirupsen/logrus"
22
+ "github.com/weaviate/weaviate/entities/models"
23
+ "github.com/weaviate/weaviate/entities/modulecapabilities"
24
+ "github.com/weaviate/weaviate/entities/moduletools"
25
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/clients"
26
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
27
+ "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
28
+ )
29
+
30
+ func New() *TransformersModule {
31
+ return &TransformersModule{}
32
+ }
33
+
34
+ type TransformersModule struct {
35
+ vectorizer textVectorizer
36
+ metaProvider metaProvider
37
+ graphqlProvider modulecapabilities.GraphQLArguments
38
+ searcher modulecapabilities.Searcher
39
+ nearTextTransformer modulecapabilities.TextTransform
40
+ logger logrus.FieldLogger
41
+ additionalPropertiesProvider modulecapabilities.AdditionalProperties
42
+ }
43
+
44
+ type textVectorizer interface {
45
+ Object(ctx context.Context, obj *models.Object, objDiff *moduletools.ObjectDiff,
46
+ cfg moduletools.ClassConfig) error
47
+ Texts(ctx context.Context, input []string,
48
+ cfg moduletools.ClassConfig) ([]float32, error)
49
+ }
50
+
51
// metaProvider supplies the meta information the module returns from
// MetaInfo.
type metaProvider interface {
	MetaInfo() (map[string]interface{}, error)
}
54
+
55
+ func (m *TransformersModule) Name() string {
56
+ return "text2vec-transformers"
57
+ }
58
+
59
+ func (m *TransformersModule) Type() modulecapabilities.ModuleType {
60
+ return modulecapabilities.Text2Vec
61
+ }
62
+
63
+ func (m *TransformersModule) Init(ctx context.Context,
64
+ params moduletools.ModuleInitParams,
65
+ ) error {
66
+ m.logger = params.GetLogger()
67
+
68
+ if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
69
+ return errors.Wrap(err, "init vectorizer")
70
+ }
71
+
72
+ if err := m.initAdditionalPropertiesProvider(); err != nil {
73
+ return errors.Wrap(err, "init additional properties provider")
74
+ }
75
+
76
+ return nil
77
+ }
78
+
79
+ func (m *TransformersModule) InitExtension(modules []modulecapabilities.Module) error {
80
+ for _, module := range modules {
81
+ if module.Name() == m.Name() {
82
+ continue
83
+ }
84
+ if arg, ok := module.(modulecapabilities.TextTransformers); ok {
85
+ if arg != nil && arg.TextTransformers() != nil {
86
+ m.nearTextTransformer = arg.TextTransformers()["nearText"]
87
+ }
88
+ }
89
+ }
90
+
91
+ if err := m.initNearText(); err != nil {
92
+ return errors.Wrap(err, "init graphql provider")
93
+ }
94
+ return nil
95
+ }
96
+
97
+ func (m *TransformersModule) initVectorizer(ctx context.Context, timeout time.Duration,
98
+ logger logrus.FieldLogger,
99
+ ) error {
100
+ // TODO: gh-1486 proper config management
101
+ uriPassage := os.Getenv("TRANSFORMERS_PASSAGE_INFERENCE_API")
102
+ uriQuery := os.Getenv("TRANSFORMERS_QUERY_INFERENCE_API")
103
+ uriCommon := os.Getenv("TRANSFORMERS_INFERENCE_API")
104
+
105
+ if uriCommon == "" {
106
+ if uriPassage == "" && uriQuery == "" {
107
+ return errors.Errorf("required variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API are not set")
108
+ }
109
+ if uriPassage != "" && uriQuery == "" {
110
+ return errors.Errorf("required variable TRANSFORMERS_QUERY_INFERENCE_API is not set")
111
+ }
112
+ if uriPassage == "" && uriQuery != "" {
113
+ return errors.Errorf("required variable TRANSFORMERS_PASSAGE_INFERENCE_API is not set")
114
+ }
115
+ } else {
116
+ if uriPassage != "" || uriQuery != "" {
117
+ return errors.Errorf("either variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API should be set")
118
+ }
119
+ uriPassage = uriCommon
120
+ uriQuery = uriCommon
121
+ }
122
+
123
+ client := clients.New(uriPassage, uriQuery, timeout, logger)
124
+ if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
125
+ return errors.Wrap(err, "init remote vectorizer")
126
+ }
127
+
128
+ m.vectorizer = vectorizer.New(client)
129
+ m.metaProvider = client
130
+
131
+ return nil
132
+ }
133
+
134
+ func (m *TransformersModule) initAdditionalPropertiesProvider() error {
135
+ m.additionalPropertiesProvider = additional.NewText2VecProvider()
136
+ return nil
137
+ }
138
+
139
+ func (m *TransformersModule) RootHandler() http.Handler {
140
+ // TODO: remove once this is a capability interface
141
+ return nil
142
+ }
143
+
144
+ func (m *TransformersModule) VectorizeObject(ctx context.Context,
145
+ obj *models.Object, objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
146
+ ) error {
147
+ return m.vectorizer.Object(ctx, obj, objDiff, cfg)
148
+ }
149
+
150
+ func (m *TransformersModule) MetaInfo() (map[string]interface{}, error) {
151
+ return m.metaProvider.MetaInfo()
152
+ }
153
+
154
+ func (m *TransformersModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
155
+ return m.additionalPropertiesProvider.AdditionalProperties()
156
+ }
157
+
158
+ func (m *TransformersModule) VectorizeInput(ctx context.Context,
159
+ input string, cfg moduletools.ClassConfig,
160
+ ) ([]float32, error) {
161
+ return m.vectorizer.Texts(ctx, []string{input}, cfg)
162
+ }
163
+
164
+ // verify we implement the modules.Module interface
165
+ var (
166
+ _ = modulecapabilities.Module(New())
167
+ _ = modulecapabilities.Vectorizer(New())
168
+ _ = modulecapabilities.MetaProvider(New())
169
+ )
modules/text2vec-transformers/nearText.go ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package modtransformers
13
+
14
+ import (
15
+ "github.com/weaviate/weaviate/entities/modulecapabilities"
16
+ "github.com/weaviate/weaviate/usecases/modulecomponents/nearText"
17
+ )
18
+
19
+ func (m *TransformersModule) initNearText() error {
20
+ m.searcher = nearText.NewSearcher(m.vectorizer)
21
+ m.graphqlProvider = nearText.New(m.nearTextTransformer)
22
+ return nil
23
+ }
24
+
25
+ func (m *TransformersModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
26
+ return m.graphqlProvider.Arguments()
27
+ }
28
+
29
+ func (m *TransformersModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
30
+ return m.searcher.VectorSearches()
31
+ }
32
+
33
+ var (
34
+ _ = modulecapabilities.GraphQLArguments(New())
35
+ _ = modulecapabilities.Searcher(New())
36
+ )
modules/text2vec-transformers/vectorizer/class_settings.go ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "github.com/weaviate/weaviate/entities/moduletools"
16
+ )
17
+
18
// Defaults used whenever a setting is absent from the configuration or has
// an unexpected type.
const (
	// DefaultPropertyIndexed: properties take part in vectorization.
	DefaultPropertyIndexed = true
	// DefaultVectorizeClassName: the class name is vectorized.
	DefaultVectorizeClassName = true
	// DefaultVectorizePropertyName: property names are not vectorized.
	DefaultVectorizePropertyName = false
	// DefaultPoolingStrategy is the pooling strategy requested from the
	// inference service.
	DefaultPoolingStrategy = "masked_mean"
)
24
+
25
+ type classSettings struct {
26
+ cfg moduletools.ClassConfig
27
+ }
28
+
29
+ func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
30
+ return &classSettings{cfg: cfg}
31
+ }
32
+
33
+ func (ic *classSettings) PropertyIndexed(propName string) bool {
34
+ if ic.cfg == nil {
35
+ // we would receive a nil-config on cross-class requests, such as Explore{}
36
+ return DefaultPropertyIndexed
37
+ }
38
+
39
+ vcn, ok := ic.cfg.Property(propName)["skip"]
40
+ if !ok {
41
+ return DefaultPropertyIndexed
42
+ }
43
+
44
+ asBool, ok := vcn.(bool)
45
+ if !ok {
46
+ return DefaultPropertyIndexed
47
+ }
48
+
49
+ return !asBool
50
+ }
51
+
52
+ func (ic *classSettings) VectorizePropertyName(propName string) bool {
53
+ if ic.cfg == nil {
54
+ // we would receive a nil-config on cross-class requests, such as Explore{}
55
+ return DefaultVectorizePropertyName
56
+ }
57
+ vcn, ok := ic.cfg.Property(propName)["vectorizePropertyName"]
58
+ if !ok {
59
+ return DefaultVectorizePropertyName
60
+ }
61
+
62
+ asBool, ok := vcn.(bool)
63
+ if !ok {
64
+ return DefaultVectorizePropertyName
65
+ }
66
+
67
+ return asBool
68
+ }
69
+
70
+ func (ic *classSettings) VectorizeClassName() bool {
71
+ if ic.cfg == nil {
72
+ // we would receive a nil-config on cross-class requests, such as Explore{}
73
+ return DefaultVectorizeClassName
74
+ }
75
+
76
+ vcn, ok := ic.cfg.Class()["vectorizeClassName"]
77
+ if !ok {
78
+ return DefaultVectorizeClassName
79
+ }
80
+
81
+ asBool, ok := vcn.(bool)
82
+ if !ok {
83
+ return DefaultVectorizeClassName
84
+ }
85
+
86
+ return asBool
87
+ }
88
+
89
+ func (ic *classSettings) PoolingStrategy() string {
90
+ if ic.cfg == nil {
91
+ // we would receive a nil-config on cross-class requests, such as Explore{}
92
+ return DefaultPoolingStrategy
93
+ }
94
+
95
+ vcn, ok := ic.cfg.Class()["poolingStrategy"]
96
+ if !ok {
97
+ return DefaultPoolingStrategy
98
+ }
99
+
100
+ asString, ok := vcn.(string)
101
+ if !ok {
102
+ return DefaultPoolingStrategy
103
+ }
104
+
105
+ return asString
106
+ }
modules/text2vec-transformers/vectorizer/class_settings_test.go ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "testing"
16
+
17
+ "github.com/stretchr/testify/assert"
18
+ "github.com/weaviate/weaviate/entities/models"
19
+ "github.com/weaviate/weaviate/usecases/modules"
20
+ )
21
+
22
+ func TestClassSettings(t *testing.T) {
23
+ t.Run("with all defaults", func(t *testing.T) {
24
+ class := &models.Class{
25
+ Class: "MyClass",
26
+ Properties: []*models.Property{{
27
+ Name: "someProp",
28
+ }},
29
+ }
30
+
31
+ cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
32
+ ic := NewClassSettings(cfg)
33
+
34
+ assert.True(t, ic.PropertyIndexed("someProp"))
35
+ assert.False(t, ic.VectorizePropertyName("someProp"))
36
+ assert.True(t, ic.VectorizeClassName())
37
+ assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
38
+ })
39
+
40
+ t.Run("with a nil config", func(t *testing.T) {
41
+ // this is the case if we were running in a situation such as a
42
+ // cross-class vectorization of search time, as is the case with Explore
43
+ // {}, we then expect all default values
44
+
45
+ ic := NewClassSettings(nil)
46
+
47
+ assert.True(t, ic.PropertyIndexed("someProp"))
48
+ assert.False(t, ic.VectorizePropertyName("someProp"))
49
+ assert.True(t, ic.VectorizeClassName())
50
+ assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
51
+ })
52
+
53
+ t.Run("with all explicit config matching the defaults", func(t *testing.T) {
54
+ class := &models.Class{
55
+ Class: "MyClass",
56
+ ModuleConfig: map[string]interface{}{
57
+ "my-module": map[string]interface{}{
58
+ "vectorizeClassName": true,
59
+ "poolingStrategy": "masked_mean",
60
+ },
61
+ },
62
+ Properties: []*models.Property{{
63
+ Name: "someProp",
64
+ ModuleConfig: map[string]interface{}{
65
+ "my-module": map[string]interface{}{
66
+ "skip": false,
67
+ "vectorizePropertyName": false,
68
+ },
69
+ },
70
+ }},
71
+ }
72
+
73
+ cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
74
+ ic := NewClassSettings(cfg)
75
+
76
+ assert.True(t, ic.PropertyIndexed("someProp"))
77
+ assert.False(t, ic.VectorizePropertyName("someProp"))
78
+ assert.True(t, ic.VectorizeClassName())
79
+ assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
80
+ })
81
+
82
+ t.Run("with all explicit config using non-default values", func(t *testing.T) {
83
+ class := &models.Class{
84
+ Class: "MyClass",
85
+ ModuleConfig: map[string]interface{}{
86
+ "my-module": map[string]interface{}{
87
+ "vectorizeClassName": false,
88
+ "poolingStrategy": "cls",
89
+ },
90
+ },
91
+ Properties: []*models.Property{{
92
+ Name: "someProp",
93
+ ModuleConfig: map[string]interface{}{
94
+ "my-module": map[string]interface{}{
95
+ "skip": true,
96
+ "vectorizePropertyName": true,
97
+ },
98
+ },
99
+ }},
100
+ }
101
+
102
+ cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
103
+ ic := NewClassSettings(cfg)
104
+
105
+ assert.False(t, ic.PropertyIndexed("someProp"))
106
+ assert.True(t, ic.VectorizePropertyName("someProp"))
107
+ assert.False(t, ic.VectorizeClassName())
108
+ assert.Equal(t, ic.PoolingStrategy(), "cls")
109
+ })
110
+ }
modules/text2vec-transformers/vectorizer/fakes_for_test.go ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "context"
16
+
17
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
18
+ )
19
+
20
+ type fakeClient struct {
21
+ lastInput string
22
+ lastConfig ent.VectorizationConfig
23
+ }
24
+
25
+ func (c *fakeClient) VectorizeObject(ctx context.Context,
26
+ text string, cfg ent.VectorizationConfig,
27
+ ) (*ent.VectorizationResult, error) {
28
+ c.lastInput = text
29
+ c.lastConfig = cfg
30
+ return &ent.VectorizationResult{
31
+ Vector: []float32{0, 1, 2, 3},
32
+ Dimensions: 4,
33
+ Text: text,
34
+ }, nil
35
+ }
36
+
37
+ func (c *fakeClient) VectorizeQuery(ctx context.Context,
38
+ text string, cfg ent.VectorizationConfig,
39
+ ) (*ent.VectorizationResult, error) {
40
+ return c.VectorizeObject(ctx, text, cfg)
41
+ }
42
+
43
// fakeClassConfig is a test double for a class configuration whose answers
// are derived from a handful of fixed fields.
type fakeClassConfig struct {
	// classConfig is returned verbatim by ClassByModuleName.
	classConfig map[string]interface{}
	// vectorizeClassName and poolingStrategy are reported by Class.
	vectorizeClassName bool
	// vectorizePropertyName makes Property report
	// "vectorizePropertyName": true for properties matching neither
	// skippedProperty nor excludedProperty.
	vectorizePropertyName bool
	// skippedProperty names the property reported with "skip": true.
	skippedProperty string
	// excludedProperty names the property reported with
	// "vectorizePropertyName": false.
	excludedProperty string
	poolingStrategy  string
}
51
+
52
+ func (f fakeClassConfig) Class() map[string]interface{} {
53
+ classSettings := map[string]interface{}{
54
+ "vectorizeClassName": f.vectorizeClassName,
55
+ "poolingStrategy": f.poolingStrategy,
56
+ }
57
+ return classSettings
58
+ }
59
+
60
+ func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
61
+ return f.classConfig
62
+ }
63
+
64
+ func (f fakeClassConfig) Property(propName string) map[string]interface{} {
65
+ if propName == f.skippedProperty {
66
+ return map[string]interface{}{
67
+ "skip": true,
68
+ }
69
+ }
70
+ if propName == f.excludedProperty {
71
+ return map[string]interface{}{
72
+ "vectorizePropertyName": false,
73
+ }
74
+ }
75
+ if f.vectorizePropertyName {
76
+ return map[string]interface{}{
77
+ "vectorizePropertyName": true,
78
+ }
79
+ }
80
+ return nil
81
+ }
82
+
83
+ func (f fakeClassConfig) Tenant() string {
84
+ return ""
85
+ }
modules/text2vec-transformers/vectorizer/objects.go ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "context"
16
+
17
+ "github.com/weaviate/weaviate/entities/models"
18
+ "github.com/weaviate/weaviate/entities/moduletools"
19
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
20
+ objectsvectorizer "github.com/weaviate/weaviate/usecases/modulecomponents/vectorizer"
21
+ )
22
+
23
+ type Vectorizer struct {
24
+ client Client
25
+ objectVectorizer *objectsvectorizer.ObjectVectorizer
26
+ }
27
+
28
+ func New(client Client) *Vectorizer {
29
+ return &Vectorizer{
30
+ client: client,
31
+ objectVectorizer: objectsvectorizer.New(),
32
+ }
33
+ }
34
+
35
+ type Client interface {
36
+ VectorizeObject(ctx context.Context, input string,
37
+ cfg ent.VectorizationConfig) (*ent.VectorizationResult, error)
38
+ VectorizeQuery(ctx context.Context, input string,
39
+ cfg ent.VectorizationConfig) (*ent.VectorizationResult, error)
40
+ }
41
+
42
// ClassSettings describes the per-class vectorization settings: which
// properties are indexed, whether class and property names are vectorized,
// and which pooling strategy to request.
type ClassSettings interface {
	PropertyIndexed(property string) bool
	VectorizePropertyName(propertyName string) bool
	VectorizeClassName() bool
	PoolingStrategy() string
}
49
+
50
+ func (v *Vectorizer) Object(ctx context.Context, object *models.Object,
51
+ objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
52
+ ) error {
53
+ vec, err := v.object(ctx, object.Class, object.Properties, objDiff, cfg)
54
+ if err != nil {
55
+ return err
56
+ }
57
+
58
+ object.Vector = vec
59
+ return nil
60
+ }
61
+
62
+ func (v *Vectorizer) object(ctx context.Context, className string,
63
+ schema interface{}, objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
64
+ ) ([]float32, error) {
65
+ text, vector, err := v.objectVectorizer.TextsOrVector(ctx, className, schema, objDiff, NewClassSettings(cfg))
66
+ if err != nil {
67
+ return nil, err
68
+ }
69
+ if vector != nil {
70
+ // dont' re-vectorize
71
+ return vector, nil
72
+ }
73
+ // vectorize text
74
+ res, err := v.client.VectorizeObject(ctx, text, ent.VectorizationConfig{
75
+ PoolingStrategy: NewClassSettings(cfg).PoolingStrategy(),
76
+ })
77
+ if err != nil {
78
+ return nil, err
79
+ }
80
+
81
+ return res.Vector, nil
82
+ }
modules/text2vec-transformers/vectorizer/objects_test.go ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "context"
16
+ "strings"
17
+ "testing"
18
+
19
+ "github.com/stretchr/testify/assert"
20
+ "github.com/stretchr/testify/require"
21
+ "github.com/weaviate/weaviate/entities/models"
22
+ "github.com/weaviate/weaviate/entities/moduletools"
23
+ )
24
+
25
+ // These are mostly copy/pasted (with minimal additions) from the
26
+ // text2vec-contextionary module
27
+ func TestVectorizingObjects(t *testing.T) {
28
+ type testCase struct {
29
+ name string
30
+ input *models.Object
31
+ expectedClientCall string
32
+ expectedPoolingStrategy string
33
+ noindex string
34
+ excludedProperty string // to simulate a schema where property names aren't vectorized
35
+ excludedClass string // to simulate a schema where class names aren't vectorized
36
+ poolingStrategy string
37
+ }
38
+
39
+ tests := []testCase{
40
+ {
41
+ name: "empty object",
42
+ input: &models.Object{
43
+ Class: "Car",
44
+ },
45
+ poolingStrategy: "cls",
46
+ expectedPoolingStrategy: "cls",
47
+ expectedClientCall: "car",
48
+ },
49
+ {
50
+ name: "object with one string prop",
51
+ input: &models.Object{
52
+ Class: "Car",
53
+ Properties: map[string]interface{}{
54
+ "brand": "Mercedes",
55
+ },
56
+ },
57
+ expectedClientCall: "car brand mercedes",
58
+ },
59
+
60
+ {
61
+ name: "object with one non-string prop",
62
+ input: &models.Object{
63
+ Class: "Car",
64
+ Properties: map[string]interface{}{
65
+ "power": 300,
66
+ },
67
+ },
68
+ expectedClientCall: "car",
69
+ },
70
+
71
+ {
72
+ name: "object with a mix of props",
73
+ input: &models.Object{
74
+ Class: "Car",
75
+ Properties: map[string]interface{}{
76
+ "brand": "best brand",
77
+ "power": 300,
78
+ "review": "a very great car",
79
+ },
80
+ },
81
+ expectedClientCall: "car brand best brand review a very great car",
82
+ },
83
+ {
84
+ name: "with a noindexed property",
85
+ noindex: "review",
86
+ input: &models.Object{
87
+ Class: "Car",
88
+ Properties: map[string]interface{}{
89
+ "brand": "best brand",
90
+ "power": 300,
91
+ "review": "a very great car",
92
+ },
93
+ },
94
+ expectedClientCall: "car brand best brand",
95
+ },
96
+
97
+ {
98
+ name: "with the class name not vectorized",
99
+ excludedClass: "Car",
100
+ input: &models.Object{
101
+ Class: "Car",
102
+ Properties: map[string]interface{}{
103
+ "brand": "best brand",
104
+ "power": 300,
105
+ "review": "a very great car",
106
+ },
107
+ },
108
+ expectedClientCall: "brand best brand review a very great car",
109
+ },
110
+
111
+ {
112
+ name: "with a property name not vectorized",
113
+ excludedProperty: "review",
114
+ input: &models.Object{
115
+ Class: "Car",
116
+ Properties: map[string]interface{}{
117
+ "brand": "best brand",
118
+ "power": 300,
119
+ "review": "a very great car",
120
+ },
121
+ },
122
+ expectedClientCall: "car brand best brand a very great car",
123
+ },
124
+
125
+ {
126
+ name: "with no schema labels vectorized",
127
+ excludedProperty: "review",
128
+ excludedClass: "Car",
129
+ input: &models.Object{
130
+ Class: "Car",
131
+ Properties: map[string]interface{}{
132
+ "review": "a very great car",
133
+ },
134
+ },
135
+ expectedClientCall: "a very great car",
136
+ },
137
+
138
+ {
139
+ name: "with string/text arrays without propname or classname",
140
+ excludedProperty: "reviews",
141
+ excludedClass: "Car",
142
+ input: &models.Object{
143
+ Class: "Car",
144
+ Properties: map[string]interface{}{
145
+ "reviews": []interface{}{
146
+ "a very great car",
147
+ "you should consider buying one",
148
+ },
149
+ },
150
+ },
151
+ expectedClientCall: "a very great car you should consider buying one",
152
+ },
153
+
154
+ {
155
+ name: "with string/text arrays with propname and classname",
156
+ input: &models.Object{
157
+ Class: "Car",
158
+ Properties: map[string]interface{}{
159
+ "reviews": []interface{}{
160
+ "a very great car",
161
+ "you should consider buying one",
162
+ },
163
+ },
164
+ },
165
+ expectedClientCall: "car reviews a very great car reviews you should consider buying one",
166
+ },
167
+
168
+ {
169
+ name: "with compound class and prop names",
170
+ input: &models.Object{
171
+ Class: "SuperCar",
172
+ Properties: map[string]interface{}{
173
+ "brandOfTheCar": "best brand",
174
+ "power": 300,
175
+ "review": "a very great car",
176
+ },
177
+ },
178
+ expectedClientCall: "super car brand of the car best brand review a very great car",
179
+ },
180
+ }
181
+
182
+ for _, test := range tests {
183
+ t.Run(test.name, func(t *testing.T) {
184
+ client := &fakeClient{}
185
+
186
+ v := New(client)
187
+
188
+ ic := &fakeClassConfig{
189
+ excludedProperty: test.excludedProperty,
190
+ skippedProperty: test.noindex,
191
+ vectorizeClassName: test.excludedClass != "Car",
192
+ poolingStrategy: test.poolingStrategy,
193
+ vectorizePropertyName: true,
194
+ }
195
+ err := v.Object(context.Background(), test.input, nil, ic)
196
+
197
+ require.Nil(t, err)
198
+ assert.Equal(t, models.C11yVector{0, 1, 2, 3}, test.input.Vector)
199
+ expected := strings.Split(test.expectedClientCall, " ")
200
+ actual := strings.Split(client.lastInput, " ")
201
+ assert.Equal(t, expected, actual)
202
+ assert.Equal(t, client.lastConfig.PoolingStrategy, test.expectedPoolingStrategy)
203
+ })
204
+ }
205
+ }
206
+
207
+ func TestVectorizingObjectsWithDiff(t *testing.T) {
208
+ type testCase struct {
209
+ name string
210
+ input *models.Object
211
+ skipped string
212
+ diff *moduletools.ObjectDiff
213
+ expectedVectorize bool
214
+ }
215
+
216
+ tests := []testCase{
217
+ {
218
+ name: "no diff",
219
+ input: &models.Object{
220
+ Class: "Car",
221
+ Properties: map[string]interface{}{
222
+ "brand": "best brand",
223
+ "power": 300,
224
+ "description": "a very great car",
225
+ "reviews": []interface{}{
226
+ "a very great car",
227
+ "you should consider buying one",
228
+ },
229
+ },
230
+ },
231
+ diff: nil,
232
+ expectedVectorize: true,
233
+ },
234
+ {
235
+ name: "diff all props unchanged",
236
+ input: &models.Object{
237
+ Class: "Car",
238
+ Properties: map[string]interface{}{
239
+ "brand": "best brand",
240
+ "power": 300,
241
+ "description": "a very great car",
242
+ "reviews": []interface{}{
243
+ "a very great car",
244
+ "you should consider buying one",
245
+ },
246
+ },
247
+ },
248
+ diff: newObjectDiffWithVector().
249
+ WithProp("brand", "best brand", "best brand").
250
+ WithProp("power", 300, 300).
251
+ WithProp("description", "a very great car", "a very great car").
252
+ WithProp("reviews", []interface{}{
253
+ "a very great car",
254
+ "you should consider buying one",
255
+ }, []interface{}{
256
+ "a very great car",
257
+ "you should consider buying one",
258
+ }),
259
+ expectedVectorize: false,
260
+ },
261
+ {
262
+ name: "diff one vectorizable prop changed (1)",
263
+ input: &models.Object{
264
+ Class: "Car",
265
+ Properties: map[string]interface{}{
266
+ "brand": "best brand",
267
+ "power": 300,
268
+ "description": "a very great car",
269
+ "reviews": []interface{}{
270
+ "a very great car",
271
+ "you should consider buying one",
272
+ },
273
+ },
274
+ },
275
+ diff: newObjectDiffWithVector().
276
+ WithProp("brand", "old best brand", "best brand"),
277
+ expectedVectorize: true,
278
+ },
279
+ {
280
+ name: "diff one vectorizable prop changed (2)",
281
+ input: &models.Object{
282
+ Class: "Car",
283
+ Properties: map[string]interface{}{
284
+ "brand": "best brand",
285
+ "power": 300,
286
+ "description": "a very great car",
287
+ "reviews": []interface{}{
288
+ "a very great car",
289
+ "you should consider buying one",
290
+ },
291
+ },
292
+ },
293
+ diff: newObjectDiffWithVector().
294
+ WithProp("description", "old a very great car", "a very great car"),
295
+ expectedVectorize: true,
296
+ },
297
+ {
298
+ name: "diff one vectorizable prop changed (3)",
299
+ input: &models.Object{
300
+ Class: "Car",
301
+ Properties: map[string]interface{}{
302
+ "brand": "best brand",
303
+ "power": 300,
304
+ "description": "a very great car",
305
+ "reviews": []interface{}{
306
+ "a very great car",
307
+ "you should consider buying one",
308
+ },
309
+ },
310
+ },
311
+ diff: newObjectDiffWithVector().
312
+ WithProp("reviews", []interface{}{
313
+ "old a very great car",
314
+ "you should consider buying one",
315
+ }, []interface{}{
316
+ "a very great car",
317
+ "you should consider buying one",
318
+ }),
319
+ expectedVectorize: true,
320
+ },
321
+ {
322
+ name: "all non-vectorizable props changed",
323
+ skipped: "description",
324
+ input: &models.Object{
325
+ Class: "Car",
326
+ Properties: map[string]interface{}{
327
+ "brand": "best brand",
328
+ "power": 300,
329
+ "description": "a very great car",
330
+ "reviews": []interface{}{
331
+ "a very great car",
332
+ "you should consider buying one",
333
+ },
334
+ },
335
+ },
336
+ diff: newObjectDiffWithVector().
337
+ WithProp("power", 123, 300).
338
+ WithProp("description", "old a very great car", "a very great car"),
339
+ expectedVectorize: false,
340
+ },
341
+ }
342
+
343
+ for _, test := range tests {
344
+ t.Run(test.name, func(t *testing.T) {
345
+ ic := &fakeClassConfig{
346
+ skippedProperty: test.skipped,
347
+ }
348
+
349
+ client := &fakeClient{}
350
+ v := New(client)
351
+
352
+ err := v.Object(context.Background(), test.input, test.diff, ic)
353
+
354
+ require.Nil(t, err)
355
+ if test.expectedVectorize {
356
+ assert.Equal(t, models.C11yVector{0, 1, 2, 3}, test.input.Vector)
357
+ assert.NotEmpty(t, client.lastInput)
358
+ } else {
359
+ assert.Equal(t, models.C11yVector{0, 0, 0, 0}, test.input.Vector)
360
+ assert.Empty(t, client.lastInput)
361
+ }
362
+ })
363
+ }
364
+ }
365
+
366
+ func newObjectDiffWithVector() *moduletools.ObjectDiff {
367
+ return moduletools.NewObjectDiff([]float32{0, 0, 0, 0})
368
+ }
modules/text2vec-transformers/vectorizer/texts.go ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "context"
16
+
17
+ "github.com/pkg/errors"
18
+ "github.com/weaviate/weaviate/entities/moduletools"
19
+ "github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
20
+ libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
21
+ )
22
+
23
+ func (v *Vectorizer) Texts(ctx context.Context, inputs []string,
24
+ cfg moduletools.ClassConfig,
25
+ ) ([]float32, error) {
26
+ vectors := make([][]float32, len(inputs))
27
+ for i := range inputs {
28
+ res, err := v.client.VectorizeQuery(ctx, inputs[i], ent.VectorizationConfig{
29
+ PoolingStrategy: NewClassSettings(cfg).PoolingStrategy(),
30
+ })
31
+ if err != nil {
32
+ return nil, errors.Wrap(err, "remote client vectorize")
33
+ }
34
+ vectors[i] = res.Vector
35
+ }
36
+
37
+ return libvectorizer.CombineVectors(vectors), nil
38
+ }
modules/text2vec-transformers/vectorizer/texts_test.go ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // _ _
2
+ // __ _____ __ ___ ___ __ _| |_ ___
3
+ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
+ // \ V V / __/ (_| |\ V /| | (_| | || __/
5
+ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
+ //
7
+ // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
+ //
9
+ // CONTACT: hello@weaviate.io
10
+ //
11
+
12
+ package vectorizer
13
+
14
+ import (
15
+ "context"
16
+ "testing"
17
+
18
+ "github.com/stretchr/testify/assert"
19
+ "github.com/stretchr/testify/require"
20
+ )
21
+
22
+ // as used in the nearText searcher
23
+ func TestVectorizingTexts(t *testing.T) {
24
+ type testCase struct {
25
+ name string
26
+ input []string
27
+ expectedPoolingStrategy string
28
+ poolingStrategy string
29
+ }
30
+
31
+ tests := []testCase{
32
+ {
33
+ name: "single word",
34
+ input: []string{"hello"},
35
+ poolingStrategy: "cls",
36
+ expectedPoolingStrategy: "cls",
37
+ },
38
+ {
39
+ name: "multiple words",
40
+ input: []string{"hello world, this is me!"},
41
+ poolingStrategy: "cls",
42
+ expectedPoolingStrategy: "cls",
43
+ },
44
+
45
+ {
46
+ name: "multiple sentences (joined with a dot)",
47
+ input: []string{"this is sentence 1", "and here's number 2"},
48
+ poolingStrategy: "cls",
49
+ expectedPoolingStrategy: "cls",
50
+ },
51
+
52
+ {
53
+ name: "multiple sentences already containing a dot",
54
+ input: []string{"this is sentence 1.", "and here's number 2"},
55
+ poolingStrategy: "cls",
56
+ expectedPoolingStrategy: "cls",
57
+ },
58
+ {
59
+ name: "multiple sentences already containing a question mark",
60
+ input: []string{"this is sentence 1?", "and here's number 2"},
61
+ poolingStrategy: "cls",
62
+ expectedPoolingStrategy: "cls",
63
+ },
64
+ {
65
+ name: "multiple sentences already containing an exclamation mark",
66
+ input: []string{"this is sentence 1!", "and here's number 2"},
67
+ poolingStrategy: "cls",
68
+ expectedPoolingStrategy: "cls",
69
+ },
70
+ {
71
+ name: "multiple sentences already containing comma",
72
+ input: []string{"this is sentence 1,", "and here's number 2"},
73
+ poolingStrategy: "cls",
74
+ expectedPoolingStrategy: "cls",
75
+ },
76
+ }
77
+
78
+ for _, test := range tests {
79
+ t.Run(test.name, func(t *testing.T) {
80
+ client := &fakeClient{}
81
+
82
+ v := New(client)
83
+
84
+ settings := &fakeClassConfig{
85
+ poolingStrategy: test.poolingStrategy,
86
+ }
87
+ vec, err := v.Texts(context.Background(), test.input, settings)
88
+
89
+ require.Nil(t, err)
90
+ assert.Equal(t, []float32{0, 1, 2, 3}, vec)
91
+ assert.Equal(t, client.lastConfig.PoolingStrategy, test.expectedPoolingStrategy)
92
+ })
93
+ }
94
+ }