Spaces:
Running
Running
MVPilgrim
committed on
Commit
·
6b502ec
1
Parent(s):
228dfde
restore
Browse files- modules/text2vec-transformers/clients/meta.go +105 -0
- modules/text2vec-transformers/clients/meta_test.go +350 -0
- modules/text2vec-transformers/clients/startup.go +109 -0
- modules/text2vec-transformers/clients/startup_test.go +197 -0
- modules/text2vec-transformers/clients/transformers.go +123 -0
- modules/text2vec-transformers/clients/transformers_test.go +114 -0
- modules/text2vec-transformers/config.go +153 -0
- modules/text2vec-transformers/config_test.go +188 -0
- modules/text2vec-transformers/ent/vectorization_config.go +16 -0
- modules/text2vec-transformers/ent/vectorization_result.go +18 -0
- modules/text2vec-transformers/module.go +169 -0
- modules/text2vec-transformers/nearText.go +36 -0
- modules/text2vec-transformers/vectorizer/class_settings.go +106 -0
- modules/text2vec-transformers/vectorizer/class_settings_test.go +110 -0
- modules/text2vec-transformers/vectorizer/fakes_for_test.go +85 -0
- modules/text2vec-transformers/vectorizer/objects.go +82 -0
- modules/text2vec-transformers/vectorizer/objects_test.go +368 -0
- modules/text2vec-transformers/vectorizer/texts.go +38 -0
- modules/text2vec-transformers/vectorizer/texts_test.go +94 -0
modules/text2vec-transformers/clients/meta.go
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"encoding/json"
|
17 |
+
"fmt"
|
18 |
+
"io"
|
19 |
+
"net/http"
|
20 |
+
"strings"
|
21 |
+
"sync"
|
22 |
+
|
23 |
+
"github.com/pkg/errors"
|
24 |
+
)
|
25 |
+
|
26 |
+
func (v *vectorizer) MetaInfo() (map[string]interface{}, error) {
|
27 |
+
type nameMetaErr struct {
|
28 |
+
name string
|
29 |
+
meta map[string]interface{}
|
30 |
+
err error
|
31 |
+
}
|
32 |
+
|
33 |
+
endpoints := map[string]string{}
|
34 |
+
if v.originPassage != v.originQuery {
|
35 |
+
endpoints["passage"] = v.urlPassage("/meta")
|
36 |
+
endpoints["query"] = v.urlQuery("/meta")
|
37 |
+
} else {
|
38 |
+
endpoints[""] = v.urlPassage("/meta")
|
39 |
+
}
|
40 |
+
|
41 |
+
var wg sync.WaitGroup
|
42 |
+
ch := make(chan nameMetaErr, len(endpoints))
|
43 |
+
for serviceName, endpoint := range endpoints {
|
44 |
+
wg.Add(1)
|
45 |
+
go func(serviceName string, endpoint string) {
|
46 |
+
defer wg.Done()
|
47 |
+
meta, err := v.metaInfo(endpoint)
|
48 |
+
ch <- nameMetaErr{serviceName, meta, err}
|
49 |
+
}(serviceName, endpoint)
|
50 |
+
}
|
51 |
+
wg.Wait()
|
52 |
+
close(ch)
|
53 |
+
|
54 |
+
metas := map[string]interface{}{}
|
55 |
+
var errs []string
|
56 |
+
for nme := range ch {
|
57 |
+
if nme.err != nil {
|
58 |
+
prefix := ""
|
59 |
+
if nme.name != "" {
|
60 |
+
prefix = "[" + nme.name + "] "
|
61 |
+
}
|
62 |
+
errs = append(errs, fmt.Sprintf("%s%v", prefix, nme.err.Error()))
|
63 |
+
}
|
64 |
+
if nme.meta != nil {
|
65 |
+
metas[nme.name] = nme.meta
|
66 |
+
}
|
67 |
+
}
|
68 |
+
|
69 |
+
if len(errs) > 0 {
|
70 |
+
return nil, errors.Errorf(strings.Join(errs, ", "))
|
71 |
+
}
|
72 |
+
if len(metas) == 1 {
|
73 |
+
for _, meta := range metas {
|
74 |
+
return meta.(map[string]interface{}), nil
|
75 |
+
}
|
76 |
+
}
|
77 |
+
return metas, nil
|
78 |
+
}
|
79 |
+
|
80 |
+
func (v *vectorizer) metaInfo(endpoint string) (map[string]interface{}, error) {
|
81 |
+
req, err := http.NewRequestWithContext(context.Background(), "GET", endpoint, nil)
|
82 |
+
if err != nil {
|
83 |
+
return nil, errors.Wrap(err, "create GET meta request")
|
84 |
+
}
|
85 |
+
|
86 |
+
res, err := v.httpClient.Do(req)
|
87 |
+
if err != nil {
|
88 |
+
return nil, errors.Wrap(err, "send GET meta request")
|
89 |
+
}
|
90 |
+
defer res.Body.Close()
|
91 |
+
if !(res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices) {
|
92 |
+
return nil, errors.Errorf("unexpected status code '%d' of meta request", res.StatusCode)
|
93 |
+
}
|
94 |
+
|
95 |
+
bodyBytes, err := io.ReadAll(res.Body)
|
96 |
+
if err != nil {
|
97 |
+
return nil, errors.Wrap(err, "read meta response body")
|
98 |
+
}
|
99 |
+
|
100 |
+
var resBody map[string]interface{}
|
101 |
+
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
|
102 |
+
return nil, errors.Wrap(err, "unmarshal meta response body")
|
103 |
+
}
|
104 |
+
return resBody, nil
|
105 |
+
}
|
modules/text2vec-transformers/clients/meta_test.go
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"net/http"
|
16 |
+
"net/http/httptest"
|
17 |
+
"testing"
|
18 |
+
"time"
|
19 |
+
|
20 |
+
"github.com/stretchr/testify/assert"
|
21 |
+
)
|
22 |
+
|
23 |
+
func TestGetMeta(t *testing.T) {
|
24 |
+
t.Run("when common server is providing meta", func(t *testing.T) {
|
25 |
+
server := httptest.NewServer(&testMetaHandler{t: t})
|
26 |
+
defer server.Close()
|
27 |
+
v := New(server.URL, server.URL, 0, nullLogger())
|
28 |
+
meta, err := v.MetaInfo()
|
29 |
+
|
30 |
+
assert.Nil(t, err)
|
31 |
+
assert.NotNil(t, meta)
|
32 |
+
|
33 |
+
model := extractChildMap(t, meta, "model")
|
34 |
+
assert.NotNil(t, model["_name_or_path"])
|
35 |
+
assert.NotNil(t, model["architectures"])
|
36 |
+
assert.Contains(t, model["architectures"], "DistilBertModel")
|
37 |
+
ID2Label := extractChildMap(t, model, "id2label")
|
38 |
+
assert.NotNil(t, ID2Label["0"])
|
39 |
+
assert.NotNil(t, ID2Label["1"])
|
40 |
+
})
|
41 |
+
|
42 |
+
t.Run("when passage and query servers are providing meta", func(t *testing.T) {
|
43 |
+
serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"})
|
44 |
+
serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"})
|
45 |
+
defer serverPassage.Close()
|
46 |
+
defer serverQuery.Close()
|
47 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
48 |
+
meta, err := v.MetaInfo()
|
49 |
+
|
50 |
+
assert.Nil(t, err)
|
51 |
+
assert.NotNil(t, meta)
|
52 |
+
|
53 |
+
passage := extractChildMap(t, meta, "passage")
|
54 |
+
passageModel := extractChildMap(t, passage, "model")
|
55 |
+
assert.NotNil(t, passageModel["_name_or_path"])
|
56 |
+
assert.NotNil(t, passageModel["architectures"])
|
57 |
+
assert.Contains(t, passageModel["architectures"], "DPRContextEncoder")
|
58 |
+
passageID2Label := extractChildMap(t, passageModel, "id2label")
|
59 |
+
assert.NotNil(t, passageID2Label["0"])
|
60 |
+
assert.NotNil(t, passageID2Label["1"])
|
61 |
+
|
62 |
+
query := extractChildMap(t, meta, "query")
|
63 |
+
queryModel := extractChildMap(t, query, "model")
|
64 |
+
assert.NotNil(t, queryModel["_name_or_path"])
|
65 |
+
assert.NotNil(t, queryModel["architectures"])
|
66 |
+
assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder")
|
67 |
+
queryID2Label := extractChildMap(t, queryModel, "id2label")
|
68 |
+
assert.NotNil(t, queryID2Label["0"])
|
69 |
+
assert.NotNil(t, queryID2Label["1"])
|
70 |
+
})
|
71 |
+
|
72 |
+
t.Run("when passage and query servers are unavailable", func(t *testing.T) {
|
73 |
+
rt := time.Now().Add(time.Hour)
|
74 |
+
serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt})
|
75 |
+
serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt})
|
76 |
+
defer serverPassage.Close()
|
77 |
+
defer serverQuery.Close()
|
78 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
79 |
+
meta, err := v.MetaInfo()
|
80 |
+
|
81 |
+
assert.NotNil(t, err)
|
82 |
+
assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request")
|
83 |
+
assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request")
|
84 |
+
assert.Nil(t, meta)
|
85 |
+
})
|
86 |
+
}
|
87 |
+
|
88 |
+
// testMetaHandler is the fake remote used by the meta tests; it answers
// requests with a canned meta payload.
type testMetaHandler struct {
	// t receives the assertions made while a request is being served
	t *testing.T
	// readyTime is the instant from which the handler starts serving the
	// payload; requests arriving earlier are answered as not ready
	readyTime time.Time
	// modelType selects the canned payload: "passage", "query", or any other
	// value for the default one
	modelType string
}
|
94 |
+
|
95 |
+
func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
96 |
+
assert.Equal(h.t, "/meta", r.URL.String())
|
97 |
+
assert.Equal(h.t, http.MethodGet, r.Method)
|
98 |
+
|
99 |
+
if time.Since(h.readyTime) < 0 {
|
100 |
+
w.WriteHeader(http.StatusServiceUnavailable)
|
101 |
+
return
|
102 |
+
}
|
103 |
+
|
104 |
+
w.Write([]byte(h.metaInfo()))
|
105 |
+
}
|
106 |
+
|
107 |
+
func (h *testMetaHandler) metaInfo() string {
|
108 |
+
switch h.modelType {
|
109 |
+
case "passage":
|
110 |
+
return `{
|
111 |
+
"model": {
|
112 |
+
"return_dict": true,
|
113 |
+
"output_hidden_states": false,
|
114 |
+
"output_attentions": false,
|
115 |
+
"torchscript": false,
|
116 |
+
"torch_dtype": "float32",
|
117 |
+
"use_bfloat16": false,
|
118 |
+
"pruned_heads": {},
|
119 |
+
"tie_word_embeddings": true,
|
120 |
+
"is_encoder_decoder": false,
|
121 |
+
"is_decoder": false,
|
122 |
+
"cross_attention_hidden_size": null,
|
123 |
+
"add_cross_attention": false,
|
124 |
+
"tie_encoder_decoder": false,
|
125 |
+
"max_length": 20,
|
126 |
+
"min_length": 0,
|
127 |
+
"do_sample": false,
|
128 |
+
"early_stopping": false,
|
129 |
+
"num_beams": 1,
|
130 |
+
"num_beam_groups": 1,
|
131 |
+
"diversity_penalty": 0,
|
132 |
+
"temperature": 1,
|
133 |
+
"top_k": 50,
|
134 |
+
"top_p": 1,
|
135 |
+
"repetition_penalty": 1,
|
136 |
+
"length_penalty": 1,
|
137 |
+
"no_repeat_ngram_size": 0,
|
138 |
+
"encoder_no_repeat_ngram_size": 0,
|
139 |
+
"bad_words_ids": null,
|
140 |
+
"num_return_sequences": 1,
|
141 |
+
"chunk_size_feed_forward": 0,
|
142 |
+
"output_scores": false,
|
143 |
+
"return_dict_in_generate": false,
|
144 |
+
"forced_bos_token_id": null,
|
145 |
+
"forced_eos_token_id": null,
|
146 |
+
"remove_invalid_values": false,
|
147 |
+
"architectures": [
|
148 |
+
"DPRContextEncoder"
|
149 |
+
],
|
150 |
+
"finetuning_task": null,
|
151 |
+
"id2label": {
|
152 |
+
"0": "LABEL_0",
|
153 |
+
"1": "LABEL_1"
|
154 |
+
},
|
155 |
+
"label2id": {
|
156 |
+
"LABEL_0": 0,
|
157 |
+
"LABEL_1": 1
|
158 |
+
},
|
159 |
+
"tokenizer_class": null,
|
160 |
+
"prefix": null,
|
161 |
+
"bos_token_id": null,
|
162 |
+
"pad_token_id": 0,
|
163 |
+
"eos_token_id": null,
|
164 |
+
"sep_token_id": null,
|
165 |
+
"decoder_start_token_id": null,
|
166 |
+
"task_specific_params": null,
|
167 |
+
"problem_type": null,
|
168 |
+
"_name_or_path": "./models/model",
|
169 |
+
"transformers_version": "4.16.2",
|
170 |
+
"gradient_checkpointing": false,
|
171 |
+
"model_type": "dpr",
|
172 |
+
"vocab_size": 30522,
|
173 |
+
"hidden_size": 768,
|
174 |
+
"num_hidden_layers": 12,
|
175 |
+
"num_attention_heads": 12,
|
176 |
+
"hidden_act": "gelu",
|
177 |
+
"intermediate_size": 3072,
|
178 |
+
"hidden_dropout_prob": 0.1,
|
179 |
+
"attention_probs_dropout_prob": 0.1,
|
180 |
+
"max_position_embeddings": 512,
|
181 |
+
"type_vocab_size": 2,
|
182 |
+
"initializer_range": 0.02,
|
183 |
+
"layer_norm_eps": 1e-12,
|
184 |
+
"projection_dim": 0,
|
185 |
+
"position_embedding_type": "absolute"
|
186 |
+
}
|
187 |
+
}`
|
188 |
+
case "query":
|
189 |
+
return `{
|
190 |
+
"model": {
|
191 |
+
"return_dict": true,
|
192 |
+
"output_hidden_states": false,
|
193 |
+
"output_attentions": false,
|
194 |
+
"torchscript": false,
|
195 |
+
"torch_dtype": "float32",
|
196 |
+
"use_bfloat16": false,
|
197 |
+
"pruned_heads": {},
|
198 |
+
"tie_word_embeddings": true,
|
199 |
+
"is_encoder_decoder": false,
|
200 |
+
"is_decoder": false,
|
201 |
+
"cross_attention_hidden_size": null,
|
202 |
+
"add_cross_attention": false,
|
203 |
+
"tie_encoder_decoder": false,
|
204 |
+
"max_length": 20,
|
205 |
+
"min_length": 0,
|
206 |
+
"do_sample": false,
|
207 |
+
"early_stopping": false,
|
208 |
+
"num_beams": 1,
|
209 |
+
"num_beam_groups": 1,
|
210 |
+
"diversity_penalty": 0,
|
211 |
+
"temperature": 1,
|
212 |
+
"top_k": 50,
|
213 |
+
"top_p": 1,
|
214 |
+
"repetition_penalty": 1,
|
215 |
+
"length_penalty": 1,
|
216 |
+
"no_repeat_ngram_size": 0,
|
217 |
+
"encoder_no_repeat_ngram_size": 0,
|
218 |
+
"bad_words_ids": null,
|
219 |
+
"num_return_sequences": 1,
|
220 |
+
"chunk_size_feed_forward": 0,
|
221 |
+
"output_scores": false,
|
222 |
+
"return_dict_in_generate": false,
|
223 |
+
"forced_bos_token_id": null,
|
224 |
+
"forced_eos_token_id": null,
|
225 |
+
"remove_invalid_values": false,
|
226 |
+
"architectures": [
|
227 |
+
"DPRQuestionEncoder"
|
228 |
+
],
|
229 |
+
"finetuning_task": null,
|
230 |
+
"id2label": {
|
231 |
+
"0": "LABEL_0",
|
232 |
+
"1": "LABEL_1"
|
233 |
+
},
|
234 |
+
"label2id": {
|
235 |
+
"LABEL_0": 0,
|
236 |
+
"LABEL_1": 1
|
237 |
+
},
|
238 |
+
"tokenizer_class": null,
|
239 |
+
"prefix": null,
|
240 |
+
"bos_token_id": null,
|
241 |
+
"pad_token_id": 0,
|
242 |
+
"eos_token_id": null,
|
243 |
+
"sep_token_id": null,
|
244 |
+
"decoder_start_token_id": null,
|
245 |
+
"task_specific_params": null,
|
246 |
+
"problem_type": null,
|
247 |
+
"_name_or_path": "./models/model",
|
248 |
+
"transformers_version": "4.16.2",
|
249 |
+
"gradient_checkpointing": false,
|
250 |
+
"model_type": "dpr",
|
251 |
+
"vocab_size": 30522,
|
252 |
+
"hidden_size": 768,
|
253 |
+
"num_hidden_layers": 12,
|
254 |
+
"num_attention_heads": 12,
|
255 |
+
"hidden_act": "gelu",
|
256 |
+
"intermediate_size": 3072,
|
257 |
+
"hidden_dropout_prob": 0.1,
|
258 |
+
"attention_probs_dropout_prob": 0.1,
|
259 |
+
"max_position_embeddings": 512,
|
260 |
+
"type_vocab_size": 2,
|
261 |
+
"initializer_range": 0.02,
|
262 |
+
"layer_norm_eps": 1e-12,
|
263 |
+
"projection_dim": 0,
|
264 |
+
"position_embedding_type": "absolute"
|
265 |
+
}
|
266 |
+
}`
|
267 |
+
default:
|
268 |
+
return `{
|
269 |
+
"model": {
|
270 |
+
"_name_or_path": "distilbert-base-uncased",
|
271 |
+
"activation": "gelu",
|
272 |
+
"add_cross_attention": false,
|
273 |
+
"architectures": [
|
274 |
+
"DistilBertModel"
|
275 |
+
],
|
276 |
+
"attention_dropout": 0.1,
|
277 |
+
"bad_words_ids": null,
|
278 |
+
"bos_token_id": null,
|
279 |
+
"chunk_size_feed_forward": 0,
|
280 |
+
"decoder_start_token_id": null,
|
281 |
+
"dim": 768,
|
282 |
+
"diversity_penalty": 0,
|
283 |
+
"do_sample": false,
|
284 |
+
"dropout": 0.1,
|
285 |
+
"early_stopping": false,
|
286 |
+
"encoder_no_repeat_ngram_size": 0,
|
287 |
+
"eos_token_id": null,
|
288 |
+
"finetuning_task": null,
|
289 |
+
"hidden_dim": 3072,
|
290 |
+
"id2label": {
|
291 |
+
"0": "LABEL_0",
|
292 |
+
"1": "LABEL_1"
|
293 |
+
},
|
294 |
+
"initializer_range": 0.02,
|
295 |
+
"is_decoder": false,
|
296 |
+
"is_encoder_decoder": false,
|
297 |
+
"label2id": {
|
298 |
+
"LABEL_0": 0,
|
299 |
+
"LABEL_1": 1
|
300 |
+
},
|
301 |
+
"length_penalty": 1,
|
302 |
+
"max_length": 20,
|
303 |
+
"max_position_embeddings": 512,
|
304 |
+
"min_length": 0,
|
305 |
+
"model_type": "distilbert",
|
306 |
+
"n_heads": 12,
|
307 |
+
"n_layers": 6,
|
308 |
+
"no_repeat_ngram_size": 0,
|
309 |
+
"num_beam_groups": 1,
|
310 |
+
"num_beams": 1,
|
311 |
+
"num_return_sequences": 1,
|
312 |
+
"output_attentions": false,
|
313 |
+
"output_hidden_states": false,
|
314 |
+
"output_scores": false,
|
315 |
+
"pad_token_id": 0,
|
316 |
+
"prefix": null,
|
317 |
+
"pruned_heads": {},
|
318 |
+
"qa_dropout": 0.1,
|
319 |
+
"repetition_penalty": 1,
|
320 |
+
"return_dict": true,
|
321 |
+
"return_dict_in_generate": false,
|
322 |
+
"sep_token_id": null,
|
323 |
+
"seq_classif_dropout": 0.2,
|
324 |
+
"sinusoidal_pos_embds": false,
|
325 |
+
"task_specific_params": null,
|
326 |
+
"temperature": 1,
|
327 |
+
"tie_encoder_decoder": false,
|
328 |
+
"tie_weights_": true,
|
329 |
+
"tie_word_embeddings": true,
|
330 |
+
"tokenizer_class": null,
|
331 |
+
"top_k": 50,
|
332 |
+
"top_p": 1,
|
333 |
+
"torchscript": false,
|
334 |
+
"transformers_version": "4.3.2",
|
335 |
+
"use_bfloat16": false,
|
336 |
+
"vocab_size": 30522,
|
337 |
+
"xla_device": null
|
338 |
+
}
|
339 |
+
}`
|
340 |
+
}
|
341 |
+
}
|
342 |
+
|
343 |
+
func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} {
|
344 |
+
assert.NotNil(t, parent[name])
|
345 |
+
child, ok := parent[name].(map[string]interface{})
|
346 |
+
assert.True(t, ok)
|
347 |
+
assert.NotNil(t, child)
|
348 |
+
|
349 |
+
return child
|
350 |
+
}
|
modules/text2vec-transformers/clients/startup.go
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"net/http"
|
17 |
+
"strings"
|
18 |
+
"sync"
|
19 |
+
"time"
|
20 |
+
|
21 |
+
"github.com/pkg/errors"
|
22 |
+
)
|
23 |
+
|
24 |
+
func (v *vectorizer) WaitForStartup(initCtx context.Context,
|
25 |
+
interval time.Duration,
|
26 |
+
) error {
|
27 |
+
endpoints := map[string]string{}
|
28 |
+
if v.originPassage != v.originQuery {
|
29 |
+
endpoints["passage"] = v.urlPassage("/.well-known/ready")
|
30 |
+
endpoints["query"] = v.urlQuery("/.well-known/ready")
|
31 |
+
} else {
|
32 |
+
endpoints[""] = v.urlPassage("/.well-known/ready")
|
33 |
+
}
|
34 |
+
|
35 |
+
ch := make(chan error, len(endpoints))
|
36 |
+
var wg sync.WaitGroup
|
37 |
+
for serviceName, endpoint := range endpoints {
|
38 |
+
wg.Add(1)
|
39 |
+
go func(serviceName string, endpoint string) {
|
40 |
+
defer wg.Done()
|
41 |
+
if err := v.waitFor(initCtx, interval, endpoint, serviceName); err != nil {
|
42 |
+
ch <- err
|
43 |
+
}
|
44 |
+
}(serviceName, endpoint)
|
45 |
+
}
|
46 |
+
wg.Wait()
|
47 |
+
close(ch)
|
48 |
+
|
49 |
+
if len(ch) > 0 {
|
50 |
+
var errs []string
|
51 |
+
for err := range ch {
|
52 |
+
errs = append(errs, err.Error())
|
53 |
+
}
|
54 |
+
return errors.Errorf(strings.Join(errs, ", "))
|
55 |
+
}
|
56 |
+
return nil
|
57 |
+
}
|
58 |
+
|
59 |
+
func (v *vectorizer) waitFor(initCtx context.Context, interval time.Duration, endpoint string, serviceName string) error {
|
60 |
+
ticker := time.NewTicker(interval)
|
61 |
+
defer ticker.Stop()
|
62 |
+
expired := initCtx.Done()
|
63 |
+
var lastErr error
|
64 |
+
prefix := ""
|
65 |
+
if serviceName != "" {
|
66 |
+
prefix = "[" + serviceName + "] "
|
67 |
+
}
|
68 |
+
|
69 |
+
for {
|
70 |
+
select {
|
71 |
+
case <-ticker.C:
|
72 |
+
lastErr = v.checkReady(initCtx, endpoint, serviceName)
|
73 |
+
if lastErr == nil {
|
74 |
+
return nil
|
75 |
+
}
|
76 |
+
v.logger.
|
77 |
+
WithField("action", "transformer_remote_wait_for_startup").
|
78 |
+
WithError(lastErr).Warnf("%stransformer remote inference service not ready", prefix)
|
79 |
+
case <-expired:
|
80 |
+
return errors.Wrapf(lastErr, "%sinit context expired before remote was ready", prefix)
|
81 |
+
}
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
func (v *vectorizer) checkReady(initCtx context.Context, endpoint string, serviceName string) error {
|
86 |
+
// spawn a new context (derived on the overall context) which is used to
|
87 |
+
// consider an individual request timed out
|
88 |
+
// due to parent timeout being superior over request's one, request can be cancelled by parent timeout
|
89 |
+
// resulting in "send check ready request" even if service is responding with non 2xx http code
|
90 |
+
requestCtx, cancel := context.WithTimeout(initCtx, 500*time.Millisecond)
|
91 |
+
defer cancel()
|
92 |
+
|
93 |
+
req, err := http.NewRequestWithContext(requestCtx, http.MethodGet, endpoint, nil)
|
94 |
+
if err != nil {
|
95 |
+
return errors.Wrap(err, "create check ready request")
|
96 |
+
}
|
97 |
+
|
98 |
+
res, err := v.httpClient.Do(req)
|
99 |
+
if err != nil {
|
100 |
+
return errors.Wrap(err, "send check ready request")
|
101 |
+
}
|
102 |
+
|
103 |
+
defer res.Body.Close()
|
104 |
+
if res.StatusCode > 299 {
|
105 |
+
return errors.Errorf("not ready: status %d", res.StatusCode)
|
106 |
+
}
|
107 |
+
|
108 |
+
return nil
|
109 |
+
}
|
modules/text2vec-transformers/clients/startup_test.go
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"net/http"
|
17 |
+
"net/http/httptest"
|
18 |
+
"regexp"
|
19 |
+
"strings"
|
20 |
+
"testing"
|
21 |
+
"time"
|
22 |
+
|
23 |
+
"github.com/sirupsen/logrus"
|
24 |
+
"github.com/sirupsen/logrus/hooks/test"
|
25 |
+
"github.com/stretchr/testify/assert"
|
26 |
+
"github.com/stretchr/testify/require"
|
27 |
+
)
|
28 |
+
|
29 |
+
func TestWaitForStartup(t *testing.T) {
|
30 |
+
t.Run("when common server is immediately ready", func(t *testing.T) {
|
31 |
+
server := httptest.NewServer(&testReadyHandler{t: t})
|
32 |
+
defer server.Close()
|
33 |
+
v := New(server.URL, server.URL, 0, nullLogger())
|
34 |
+
err := v.WaitForStartup(context.Background(), 150*time.Millisecond)
|
35 |
+
|
36 |
+
assert.Nil(t, err)
|
37 |
+
})
|
38 |
+
|
39 |
+
t.Run("when passage and query servers are immediately ready", func(t *testing.T) {
|
40 |
+
serverPassage := httptest.NewServer(&testReadyHandler{t: t})
|
41 |
+
serverQuery := httptest.NewServer(&testReadyHandler{t: t})
|
42 |
+
defer serverPassage.Close()
|
43 |
+
defer serverQuery.Close()
|
44 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
45 |
+
err := v.WaitForStartup(context.Background(), 150*time.Millisecond)
|
46 |
+
|
47 |
+
assert.Nil(t, err)
|
48 |
+
})
|
49 |
+
|
50 |
+
t.Run("when common server is down", func(t *testing.T) {
|
51 |
+
url := "http://nothing-running-at-this-url"
|
52 |
+
v := New(url, url, 0, nullLogger())
|
53 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
54 |
+
defer cancel()
|
55 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
56 |
+
|
57 |
+
require.NotNil(t, err, nullLogger())
|
58 |
+
assert.Contains(t, err.Error(), "init context expired before remote was ready: send check ready request")
|
59 |
+
assertContainsEither(t, err.Error(), "dial tcp", "context deadline exceeded")
|
60 |
+
assert.NotContains(t, err.Error(), "[passage]")
|
61 |
+
assert.NotContains(t, err.Error(), "[query]")
|
62 |
+
})
|
63 |
+
|
64 |
+
t.Run("when passage and query servers are down", func(t *testing.T) {
|
65 |
+
urlPassage := "http://nothing-running-at-this-url"
|
66 |
+
urlQuery := "http://nothing-running-at-this-url-either"
|
67 |
+
v := New(urlPassage, urlQuery, 0, nullLogger())
|
68 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
69 |
+
defer cancel()
|
70 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
71 |
+
|
72 |
+
require.NotNil(t, err, nullLogger())
|
73 |
+
assert.Contains(t, err.Error(), "[passage] init context expired before remote was ready: send check ready request")
|
74 |
+
assert.Contains(t, err.Error(), "[query] init context expired before remote was ready: send check ready request")
|
75 |
+
assertContainsEither(t, err.Error(), "dial tcp", "context deadline exceeded")
|
76 |
+
})
|
77 |
+
|
78 |
+
t.Run("when common server is alive, but not ready", func(t *testing.T) {
|
79 |
+
server := httptest.NewServer(&testReadyHandler{
|
80 |
+
t: t,
|
81 |
+
readyTime: time.Now().Add(time.Hour),
|
82 |
+
})
|
83 |
+
defer server.Close()
|
84 |
+
v := New(server.URL, server.URL, 0, nullLogger())
|
85 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
86 |
+
defer cancel()
|
87 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
88 |
+
|
89 |
+
require.NotNil(t, err)
|
90 |
+
assert.Contains(t, err.Error(), "init context expired before remote was ready")
|
91 |
+
assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
|
92 |
+
assert.NotContains(t, err.Error(), "[passage]")
|
93 |
+
assert.NotContains(t, err.Error(), "[query]")
|
94 |
+
})
|
95 |
+
|
96 |
+
t.Run("when passage and query servers are alive, but not ready", func(t *testing.T) {
|
97 |
+
rt := time.Now().Add(time.Hour)
|
98 |
+
serverPassage := httptest.NewServer(&testReadyHandler{
|
99 |
+
t: t,
|
100 |
+
readyTime: rt,
|
101 |
+
})
|
102 |
+
serverQuery := httptest.NewServer(&testReadyHandler{
|
103 |
+
t: t,
|
104 |
+
readyTime: rt,
|
105 |
+
})
|
106 |
+
defer serverPassage.Close()
|
107 |
+
defer serverQuery.Close()
|
108 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
109 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
110 |
+
defer cancel()
|
111 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
112 |
+
|
113 |
+
require.NotNil(t, err)
|
114 |
+
assert.Contains(t, err.Error(), "[passage] init context expired before remote was ready")
|
115 |
+
assert.Contains(t, err.Error(), "[query] init context expired before remote was ready")
|
116 |
+
assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
|
117 |
+
})
|
118 |
+
|
119 |
+
t.Run("when passage and query servers are alive, but query one is not ready", func(t *testing.T) {
|
120 |
+
serverPassage := httptest.NewServer(&testReadyHandler{t: t})
|
121 |
+
serverQuery := httptest.NewServer(&testReadyHandler{
|
122 |
+
t: t,
|
123 |
+
readyTime: time.Now().Add(1 * time.Minute),
|
124 |
+
})
|
125 |
+
defer serverPassage.Close()
|
126 |
+
defer serverQuery.Close()
|
127 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
128 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
129 |
+
defer cancel()
|
130 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
131 |
+
|
132 |
+
require.NotNil(t, err)
|
133 |
+
assert.Contains(t, err.Error(), "[query] init context expired before remote was ready")
|
134 |
+
assertContainsEither(t, err.Error(), "not ready: status 503", "context deadline exceeded")
|
135 |
+
assert.NotContains(t, err.Error(), "[passage]")
|
136 |
+
})
|
137 |
+
|
138 |
+
t.Run("when common server is initially not ready, but then becomes ready", func(t *testing.T) {
|
139 |
+
server := httptest.NewServer(&testReadyHandler{
|
140 |
+
t: t,
|
141 |
+
readyTime: time.Now().Add(100 * time.Millisecond),
|
142 |
+
})
|
143 |
+
v := New(server.URL, server.URL, 0, nullLogger())
|
144 |
+
defer server.Close()
|
145 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
146 |
+
defer cancel()
|
147 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
148 |
+
|
149 |
+
require.Nil(t, err)
|
150 |
+
})
|
151 |
+
|
152 |
+
t.Run("when passage and query servers are initially not ready, but then become ready", func(t *testing.T) {
|
153 |
+
serverPassage := httptest.NewServer(&testReadyHandler{
|
154 |
+
t: t,
|
155 |
+
readyTime: time.Now().Add(100 * time.Millisecond),
|
156 |
+
})
|
157 |
+
serverQuery := httptest.NewServer(&testReadyHandler{
|
158 |
+
t: t,
|
159 |
+
readyTime: time.Now().Add(150 * time.Millisecond),
|
160 |
+
})
|
161 |
+
defer serverPassage.Close()
|
162 |
+
defer serverQuery.Close()
|
163 |
+
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
|
164 |
+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
165 |
+
defer cancel()
|
166 |
+
err := v.WaitForStartup(ctx, 50*time.Millisecond)
|
167 |
+
|
168 |
+
require.Nil(t, err)
|
169 |
+
})
|
170 |
+
}
|
171 |
+
|
172 |
+
// testReadyHandler is the fake remote used by the startup tests; it answers
// readiness probes.
type testReadyHandler struct {
	// t receives the assertions made while a request is being served
	t *testing.T
	// readyTime is the instant from which the handler reports itself as
	// ready; probes arriving earlier are answered as not ready
	readyTime time.Time
}
|
177 |
+
|
178 |
+
func (f *testReadyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
179 |
+
assert.Equal(f.t, "/.well-known/ready", r.URL.String())
|
180 |
+
assert.Equal(f.t, http.MethodGet, r.Method)
|
181 |
+
|
182 |
+
if time.Since(f.readyTime) < 0 {
|
183 |
+
w.WriteHeader(http.StatusServiceUnavailable)
|
184 |
+
} else {
|
185 |
+
w.WriteHeader(http.StatusNoContent)
|
186 |
+
}
|
187 |
+
}
|
188 |
+
|
189 |
+
func nullLogger() logrus.FieldLogger {
|
190 |
+
l, _ := test.NewNullLogger()
|
191 |
+
return l
|
192 |
+
}
|
193 |
+
|
194 |
+
func assertContainsEither(t *testing.T, str string, contains ...string) {
|
195 |
+
reg := regexp.MustCompile(strings.Join(contains, "|"))
|
196 |
+
assert.Regexp(t, reg, str)
|
197 |
+
}
|
modules/text2vec-transformers/clients/transformers.go
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"bytes"
|
16 |
+
"context"
|
17 |
+
"encoding/json"
|
18 |
+
"fmt"
|
19 |
+
"io"
|
20 |
+
"net/http"
|
21 |
+
"time"
|
22 |
+
|
23 |
+
"github.com/pkg/errors"
|
24 |
+
"github.com/sirupsen/logrus"
|
25 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
|
26 |
+
)
|
27 |
+
|
28 |
+
// vectorizer is an HTTP client for the transformers inference service(s).
type vectorizer struct {
	// originPassage is the base URL that VectorizeObject requests are sent to.
	originPassage string
	// originQuery is the base URL that VectorizeQuery requests are sent to.
	originQuery string
	// httpClient performs all outgoing requests.
	httpClient *http.Client
	logger     logrus.FieldLogger
}
|
34 |
+
|
35 |
+
func New(originPassage, originQuery string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
|
36 |
+
return &vectorizer{
|
37 |
+
originPassage: originPassage,
|
38 |
+
originQuery: originQuery,
|
39 |
+
httpClient: &http.Client{
|
40 |
+
Timeout: timeout,
|
41 |
+
},
|
42 |
+
logger: logger,
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
// VectorizeObject vectorizes input using the passage origin
// (see urlPassage).
func (v *vectorizer) VectorizeObject(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, config, v.urlPassage)
}
|
51 |
+
|
52 |
+
// VectorizeQuery vectorizes input using the query origin (see urlQuery).
func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, config, v.urlQuery)
}
|
57 |
+
|
58 |
+
func (v *vectorizer) vectorize(ctx context.Context, input string,
|
59 |
+
config ent.VectorizationConfig, url func(string) string,
|
60 |
+
) (*ent.VectorizationResult, error) {
|
61 |
+
body, err := json.Marshal(vecRequest{
|
62 |
+
Text: input,
|
63 |
+
Config: vecRequestConfig{
|
64 |
+
PoolingStrategy: config.PoolingStrategy,
|
65 |
+
},
|
66 |
+
})
|
67 |
+
if err != nil {
|
68 |
+
return nil, errors.Wrapf(err, "marshal body")
|
69 |
+
}
|
70 |
+
|
71 |
+
req, err := http.NewRequestWithContext(ctx, "POST", url("/vectors"),
|
72 |
+
bytes.NewReader(body))
|
73 |
+
if err != nil {
|
74 |
+
return nil, errors.Wrap(err, "create POST request")
|
75 |
+
}
|
76 |
+
|
77 |
+
res, err := v.httpClient.Do(req)
|
78 |
+
if err != nil {
|
79 |
+
return nil, errors.Wrap(err, "send POST request")
|
80 |
+
}
|
81 |
+
defer res.Body.Close()
|
82 |
+
|
83 |
+
bodyBytes, err := io.ReadAll(res.Body)
|
84 |
+
if err != nil {
|
85 |
+
return nil, errors.Wrap(err, "read response body")
|
86 |
+
}
|
87 |
+
|
88 |
+
var resBody vecRequest
|
89 |
+
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
|
90 |
+
return nil, errors.Wrap(err, "unmarshal response body")
|
91 |
+
}
|
92 |
+
|
93 |
+
if res.StatusCode > 399 {
|
94 |
+
return nil, errors.Errorf("fail with status %d: %s", res.StatusCode,
|
95 |
+
resBody.Error)
|
96 |
+
}
|
97 |
+
|
98 |
+
return &ent.VectorizationResult{
|
99 |
+
Text: resBody.Text,
|
100 |
+
Dimensions: resBody.Dims,
|
101 |
+
Vector: resBody.Vector,
|
102 |
+
}, nil
|
103 |
+
}
|
104 |
+
|
105 |
+
// urlPassage returns path appended to the passage origin.
func (v *vectorizer) urlPassage(path string) string {
	return fmt.Sprintf("%s%s", v.originPassage, path)
}
|
108 |
+
|
109 |
+
// urlQuery returns path appended to the query origin.
func (v *vectorizer) urlQuery(path string) string {
	return fmt.Sprintf("%s%s", v.originQuery, path)
}
|
112 |
+
|
113 |
+
// vecRequest is the JSON payload exchanged with the "/vectors" endpoint.
// vectorize uses it both to encode the request (Text, Config) and to decode
// the response (Text, Dims, Vector, Error).
type vecRequest struct {
	Text   string           `json:"text"`
	Dims   int              `json:"dims"`
	Vector []float32        `json:"vector"`
	Error  string           `json:"error"`
	Config vecRequestConfig `json:"config"`
}
|
120 |
+
|
121 |
+
// vecRequestConfig carries the per-request options of a vecRequest.
type vecRequestConfig struct {
	PoolingStrategy string `json:"pooling_strategy"`
}
|
modules/text2vec-transformers/clients/transformers_test.go
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package clients
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"encoding/json"
|
17 |
+
"fmt"
|
18 |
+
"io"
|
19 |
+
"net/http"
|
20 |
+
"net/http/httptest"
|
21 |
+
"testing"
|
22 |
+
"time"
|
23 |
+
|
24 |
+
"github.com/pkg/errors"
|
25 |
+
"github.com/stretchr/testify/assert"
|
26 |
+
"github.com/stretchr/testify/require"
|
27 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
|
28 |
+
)
|
29 |
+
|
30 |
+
func TestClient(t *testing.T) {
|
31 |
+
t.Run("when all is fine", func(t *testing.T) {
|
32 |
+
server := httptest.NewServer(&fakeHandler{t: t})
|
33 |
+
defer server.Close()
|
34 |
+
c := New(server.URL, server.URL, 0, nullLogger())
|
35 |
+
expected := &ent.VectorizationResult{
|
36 |
+
Text: "This is my text",
|
37 |
+
Vector: []float32{0.1, 0.2, 0.3},
|
38 |
+
Dimensions: 3,
|
39 |
+
}
|
40 |
+
res, err := c.VectorizeObject(context.Background(), "This is my text",
|
41 |
+
ent.VectorizationConfig{
|
42 |
+
PoolingStrategy: "masked_mean",
|
43 |
+
})
|
44 |
+
|
45 |
+
assert.Nil(t, err)
|
46 |
+
assert.Equal(t, expected, res)
|
47 |
+
})
|
48 |
+
|
49 |
+
t.Run("when the context is expired", func(t *testing.T) {
|
50 |
+
server := httptest.NewServer(&fakeHandler{t: t})
|
51 |
+
defer server.Close()
|
52 |
+
c := New(server.URL, server.URL, 0, nullLogger())
|
53 |
+
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
54 |
+
defer cancel()
|
55 |
+
|
56 |
+
_, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{})
|
57 |
+
|
58 |
+
require.NotNil(t, err)
|
59 |
+
assert.Contains(t, err.Error(), "context deadline exceeded")
|
60 |
+
})
|
61 |
+
|
62 |
+
t.Run("when the server returns an error", func(t *testing.T) {
|
63 |
+
server := httptest.NewServer(&fakeHandler{
|
64 |
+
t: t,
|
65 |
+
serverError: errors.Errorf("nope, not gonna happen"),
|
66 |
+
})
|
67 |
+
defer server.Close()
|
68 |
+
c := New(server.URL, server.URL, 0, nullLogger())
|
69 |
+
_, err := c.VectorizeObject(context.Background(), "This is my text",
|
70 |
+
ent.VectorizationConfig{})
|
71 |
+
|
72 |
+
require.NotNil(t, err)
|
73 |
+
assert.Contains(t, err.Error(), "nope, not gonna happen")
|
74 |
+
})
|
75 |
+
}
|
76 |
+
|
77 |
+
// fakeHandler is an http.Handler that imitates the "/vectors" endpoint of the
// inference service for TestClient.
type fakeHandler struct {
	// t receives the assertion failures raised while serving a request.
	t *testing.T
	// serverError, when non-nil, makes the handler answer with a 500 whose
	// JSON "error" field holds the error text.
	serverError error
}
|
81 |
+
|
82 |
+
func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
83 |
+
assert.Equal(f.t, "/vectors", r.URL.String())
|
84 |
+
assert.Equal(f.t, http.MethodPost, r.Method)
|
85 |
+
|
86 |
+
if f.serverError != nil {
|
87 |
+
w.WriteHeader(http.StatusInternalServerError)
|
88 |
+
w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error())))
|
89 |
+
return
|
90 |
+
}
|
91 |
+
|
92 |
+
bodyBytes, err := io.ReadAll(r.Body)
|
93 |
+
require.Nil(f.t, err)
|
94 |
+
defer r.Body.Close()
|
95 |
+
|
96 |
+
var b map[string]interface{}
|
97 |
+
require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
|
98 |
+
|
99 |
+
textInput := b["text"].(string)
|
100 |
+
assert.Greater(f.t, len(textInput), 0)
|
101 |
+
|
102 |
+
pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string)
|
103 |
+
assert.Equal(f.t, "masked_mean", pooling)
|
104 |
+
|
105 |
+
out := map[string]interface{}{
|
106 |
+
"text": textInput,
|
107 |
+
"dims": 3,
|
108 |
+
"vector": []float32{0.1, 0.2, 0.3},
|
109 |
+
}
|
110 |
+
outBytes, err := json.Marshal(out)
|
111 |
+
require.Nil(f.t, err)
|
112 |
+
|
113 |
+
w.Write(outBytes)
|
114 |
+
}
|
modules/text2vec-transformers/config.go
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package modtransformers
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"fmt"
|
17 |
+
|
18 |
+
"github.com/pkg/errors"
|
19 |
+
"github.com/sirupsen/logrus"
|
20 |
+
"github.com/weaviate/weaviate/entities/models"
|
21 |
+
"github.com/weaviate/weaviate/entities/modulecapabilities"
|
22 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
23 |
+
"github.com/weaviate/weaviate/entities/schema"
|
24 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
|
25 |
+
)
|
26 |
+
|
27 |
+
// ClassConfigDefaults returns the class-level module settings
// ("vectorizeClassName", "poolingStrategy") with the default values defined
// in the vectorizer package.
func (m *TransformersModule) ClassConfigDefaults() map[string]interface{} {
	return map[string]interface{}{
		"vectorizeClassName": vectorizer.DefaultVectorizeClassName,
		"poolingStrategy":    vectorizer.DefaultPoolingStrategy,
	}
}
|
33 |
+
|
34 |
+
// PropertyConfigDefaults returns the property-level module settings ("skip",
// "vectorizePropertyName") with their default values. The data type dt is not
// consulted; the same defaults are returned for every property.
func (m *TransformersModule) PropertyConfigDefaults(
	dt *schema.DataType,
) map[string]interface{} {
	return map[string]interface{}{
		// "skip" is the inverse of "indexed"
		"skip":                  !vectorizer.DefaultPropertyIndexed,
		"vectorizePropertyName": vectorizer.DefaultVectorizePropertyName,
	}
}
|
42 |
+
|
43 |
+
// ValidateClass builds the class settings from cfg and runs them, together
// with class, through a ConfigValidator.
func (m *TransformersModule) ValidateClass(ctx context.Context,
	class *models.Class, cfg moduletools.ClassConfig,
) error {
	settings := vectorizer.NewClassSettings(cfg)
	return NewConfigValidator(m.logger).Do(ctx, class, cfg, settings)
}
|
49 |
+
|
50 |
+
// compile-time check that TransformersModule implements ClassConfigurator
var _ = modulecapabilities.ClassConfigurator(New())
|
51 |
+
|
52 |
+
// ConfigValidator validates the module-specific configuration of a class.
type ConfigValidator struct {
	// logger receives the warning about possible duplicate vectors.
	logger logrus.FieldLogger
}
|
55 |
+
|
56 |
+
// ClassSettings is the view on a class's module settings that the
// ConfigValidator needs.
type ClassSettings interface {
	VectorizeClassName() bool
	VectorizePropertyName(propName string) bool
	PropertyIndexed(propName string) bool
}
|
61 |
+
|
62 |
+
// NewConfigValidator returns a ConfigValidator that logs to logger.
func NewConfigValidator(logger logrus.FieldLogger) *ConfigValidator {
	return &ConfigValidator{logger: logger}
}
|
65 |
+
|
66 |
+
func (cv *ConfigValidator) Do(ctx context.Context, class *models.Class,
|
67 |
+
cfg moduletools.ClassConfig, settings ClassSettings,
|
68 |
+
) error {
|
69 |
+
// In text2vec-transformers (as opposed to e.g. text2vec-contextionary) the
|
70 |
+
// assumption is that the models will be able to deal with any words, even
|
71 |
+
// previously unseen ones. Therefore we do not need to validate individual
|
72 |
+
// properties, but only the overall "index state"
|
73 |
+
|
74 |
+
if err := cv.validateIndexState(ctx, class, settings); err != nil {
|
75 |
+
return errors.Errorf("invalid combination of properties")
|
76 |
+
}
|
77 |
+
|
78 |
+
cv.checkForPossibilityOfDuplicateVectors(ctx, class, settings)
|
79 |
+
|
80 |
+
return nil
|
81 |
+
}
|
82 |
+
|
83 |
+
func (cv *ConfigValidator) validateIndexState(ctx context.Context,
|
84 |
+
class *models.Class, settings ClassSettings,
|
85 |
+
) error {
|
86 |
+
if settings.VectorizeClassName() {
|
87 |
+
// if the user chooses to vectorize the classname, vector-building will
|
88 |
+
// always be possible, no need to investigate further
|
89 |
+
|
90 |
+
return nil
|
91 |
+
}
|
92 |
+
|
93 |
+
// search if there is at least one indexed, string/text prop. If found pass
|
94 |
+
// validation
|
95 |
+
for _, prop := range class.Properties {
|
96 |
+
if len(prop.DataType) < 1 {
|
97 |
+
return errors.Errorf("property %s must have at least one datatype: "+
|
98 |
+
"got %v", prop.Name, prop.DataType)
|
99 |
+
}
|
100 |
+
|
101 |
+
if prop.DataType[0] != string(schema.DataTypeText) {
|
102 |
+
// we can only vectorize text-like props
|
103 |
+
continue
|
104 |
+
}
|
105 |
+
|
106 |
+
if settings.PropertyIndexed(prop.Name) {
|
107 |
+
// found at least one, this is a valid schema
|
108 |
+
return nil
|
109 |
+
}
|
110 |
+
}
|
111 |
+
|
112 |
+
return fmt.Errorf("invalid properties: didn't find a single property which is " +
|
113 |
+
"of type string or text and is not excluded from indexing. In addition the " +
|
114 |
+
"class name is excluded from vectorization as well, meaning that it cannot be " +
|
115 |
+
"used to determine the vector position. To fix this, set 'vectorizeClassName' " +
|
116 |
+
"to true if the class name is contextionary-valid. Alternatively add at least " +
|
117 |
+
"contextionary-valid text/string property which is not excluded from " +
|
118 |
+
"indexing.")
|
119 |
+
}
|
120 |
+
|
121 |
+
func (cv *ConfigValidator) checkForPossibilityOfDuplicateVectors(
|
122 |
+
ctx context.Context, class *models.Class, settings ClassSettings,
|
123 |
+
) {
|
124 |
+
if !settings.VectorizeClassName() {
|
125 |
+
// if the user choses not to vectorize the class name, this means they must
|
126 |
+
// have chosen something else to vectorize, otherwise the validation would
|
127 |
+
// have error'd before we ever got here. We can skip further checking.
|
128 |
+
|
129 |
+
return
|
130 |
+
}
|
131 |
+
|
132 |
+
// search if there is at least one indexed, string/text prop. If found exit
|
133 |
+
for _, prop := range class.Properties {
|
134 |
+
// length check skipped, because validation has already passed
|
135 |
+
if prop.DataType[0] != string(schema.DataTypeText) {
|
136 |
+
// we can only vectorize text-like props
|
137 |
+
continue
|
138 |
+
}
|
139 |
+
|
140 |
+
if settings.PropertyIndexed(prop.Name) {
|
141 |
+
// found at least one
|
142 |
+
return
|
143 |
+
}
|
144 |
+
}
|
145 |
+
|
146 |
+
cv.logger.WithField("module", "text2vec-transformers").
|
147 |
+
WithField("class", class.Class).
|
148 |
+
Warnf("text2vec-contextionary: Class %q does not have any properties "+
|
149 |
+
"indexed (or only non text-properties indexed) and the vector position is "+
|
150 |
+
"only determined by the class name. Each object will end up with the same "+
|
151 |
+
"vector which leads to a severe performance penalty on imports. Consider "+
|
152 |
+
"setting vectorIndexConfig.skip=true for this property", class.Class)
|
153 |
+
}
|
modules/text2vec-transformers/config_test.go
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package modtransformers
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"testing"
|
17 |
+
|
18 |
+
"github.com/sirupsen/logrus"
|
19 |
+
ltest "github.com/sirupsen/logrus/hooks/test"
|
20 |
+
"github.com/stretchr/testify/assert"
|
21 |
+
"github.com/stretchr/testify/require"
|
22 |
+
"github.com/weaviate/weaviate/entities/models"
|
23 |
+
"github.com/weaviate/weaviate/entities/schema"
|
24 |
+
)
|
25 |
+
|
26 |
+
func TestConfigDefaults(t *testing.T) {
|
27 |
+
t.Run("for properties", func(t *testing.T) {
|
28 |
+
def := New().ClassConfigDefaults()
|
29 |
+
|
30 |
+
assert.Equal(t, true, def["vectorizeClassName"])
|
31 |
+
assert.Equal(t, "masked_mean", def["poolingStrategy"])
|
32 |
+
})
|
33 |
+
|
34 |
+
t.Run("for the class", func(t *testing.T) {
|
35 |
+
dt := schema.DataTypeText
|
36 |
+
def := New().PropertyConfigDefaults(&dt)
|
37 |
+
assert.Equal(t, false, def["vectorizePropertyName"])
|
38 |
+
assert.Equal(t, false, def["skip"])
|
39 |
+
})
|
40 |
+
}
|
41 |
+
|
42 |
+
func TestConfigValidator(t *testing.T) {
|
43 |
+
t.Run("all usable props no-indexed", func(t *testing.T) {
|
44 |
+
t.Run("all schema vectorization turned off", func(t *testing.T) {
|
45 |
+
class := &models.Class{
|
46 |
+
Vectorizer: "text2vec-contextionary",
|
47 |
+
Class: "ValidName",
|
48 |
+
Properties: []*models.Property{
|
49 |
+
{
|
50 |
+
DataType: []string{"text"},
|
51 |
+
Name: "description",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
DataType: schema.DataTypeText.PropString(),
|
55 |
+
Tokenization: models.PropertyTokenizationWhitespace,
|
56 |
+
Name: "name",
|
57 |
+
},
|
58 |
+
{
|
59 |
+
DataType: []string{"int"},
|
60 |
+
Name: "amount",
|
61 |
+
},
|
62 |
+
},
|
63 |
+
}
|
64 |
+
|
65 |
+
logger, _ := ltest.NewNullLogger()
|
66 |
+
v := NewConfigValidator(logger)
|
67 |
+
err := v.Do(context.Background(), class, nil, &fakeIndexChecker{
|
68 |
+
vectorizePropertyName: false,
|
69 |
+
vectorizeClassName: false,
|
70 |
+
propertyIndexed: false,
|
71 |
+
})
|
72 |
+
assert.NotNil(t, err)
|
73 |
+
})
|
74 |
+
})
|
75 |
+
}
|
76 |
+
|
77 |
+
func TestConfigValidator_RiskOfDuplicateVectors(t *testing.T) {
|
78 |
+
type test struct {
|
79 |
+
name string
|
80 |
+
in *models.Class
|
81 |
+
expectWarning bool
|
82 |
+
indexChecker *fakeIndexChecker
|
83 |
+
}
|
84 |
+
|
85 |
+
tests := []test{
|
86 |
+
{
|
87 |
+
name: "usable properties",
|
88 |
+
in: &models.Class{
|
89 |
+
Class: "ValidName",
|
90 |
+
Properties: []*models.Property{
|
91 |
+
{
|
92 |
+
DataType: []string{string(schema.DataTypeText)},
|
93 |
+
Name: "textProp",
|
94 |
+
},
|
95 |
+
},
|
96 |
+
},
|
97 |
+
expectWarning: false,
|
98 |
+
indexChecker: &fakeIndexChecker{
|
99 |
+
vectorizePropertyName: false,
|
100 |
+
vectorizeClassName: true,
|
101 |
+
propertyIndexed: true,
|
102 |
+
},
|
103 |
+
},
|
104 |
+
{
|
105 |
+
name: "no properties",
|
106 |
+
in: &models.Class{
|
107 |
+
Class: "ValidName",
|
108 |
+
},
|
109 |
+
expectWarning: true,
|
110 |
+
indexChecker: &fakeIndexChecker{
|
111 |
+
vectorizePropertyName: false,
|
112 |
+
vectorizeClassName: true,
|
113 |
+
propertyIndexed: false,
|
114 |
+
},
|
115 |
+
},
|
116 |
+
{
|
117 |
+
name: "usable properties, but they are no-indexed",
|
118 |
+
in: &models.Class{
|
119 |
+
Class: "ValidName",
|
120 |
+
Properties: []*models.Property{
|
121 |
+
{
|
122 |
+
DataType: []string{string(schema.DataTypeText)},
|
123 |
+
Name: "textProp",
|
124 |
+
},
|
125 |
+
},
|
126 |
+
},
|
127 |
+
expectWarning: true,
|
128 |
+
indexChecker: &fakeIndexChecker{
|
129 |
+
vectorizePropertyName: false,
|
130 |
+
vectorizeClassName: true,
|
131 |
+
propertyIndexed: false,
|
132 |
+
},
|
133 |
+
},
|
134 |
+
{
|
135 |
+
name: "only unusable properties",
|
136 |
+
in: &models.Class{
|
137 |
+
Class: "ValidName",
|
138 |
+
Properties: []*models.Property{
|
139 |
+
{
|
140 |
+
DataType: []string{string(schema.DataTypeInt)},
|
141 |
+
Name: "intProp",
|
142 |
+
},
|
143 |
+
},
|
144 |
+
},
|
145 |
+
expectWarning: true,
|
146 |
+
indexChecker: &fakeIndexChecker{
|
147 |
+
vectorizePropertyName: false,
|
148 |
+
vectorizeClassName: true,
|
149 |
+
propertyIndexed: false,
|
150 |
+
},
|
151 |
+
},
|
152 |
+
}
|
153 |
+
|
154 |
+
for _, test := range tests {
|
155 |
+
t.Run(test.name, func(t *testing.T) {
|
156 |
+
logger, hook := ltest.NewNullLogger()
|
157 |
+
v := NewConfigValidator(logger)
|
158 |
+
err := v.Do(context.Background(), test.in, nil, test.indexChecker)
|
159 |
+
require.Nil(t, err)
|
160 |
+
|
161 |
+
entry := hook.LastEntry()
|
162 |
+
if test.expectWarning {
|
163 |
+
require.NotNil(t, entry)
|
164 |
+
assert.Equal(t, logrus.WarnLevel, entry.Level)
|
165 |
+
} else {
|
166 |
+
assert.Nil(t, entry)
|
167 |
+
}
|
168 |
+
})
|
169 |
+
}
|
170 |
+
}
|
171 |
+
|
172 |
+
// fakeIndexChecker is a ClassSettings stub that answers every question with a
// fixed value, regardless of the property name asked about.
type fakeIndexChecker struct {
	vectorizeClassName    bool
	vectorizePropertyName bool
	propertyIndexed       bool
}

// VectorizeClassName returns the configured vectorizeClassName value.
func (f *fakeIndexChecker) VectorizeClassName() bool {
	return f.vectorizeClassName
}

// VectorizePropertyName returns the configured vectorizePropertyName value;
// propName is ignored.
func (f *fakeIndexChecker) VectorizePropertyName(propName string) bool {
	return f.vectorizePropertyName
}

// PropertyIndexed returns the configured propertyIndexed value; propName is
// ignored.
func (f *fakeIndexChecker) PropertyIndexed(propName string) bool {
	return f.propertyIndexed
}
|
modules/text2vec-transformers/ent/vectorization_config.go
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package ent
|
13 |
+
|
14 |
+
// VectorizationConfig holds the per-request options for a vectorization
// call.
type VectorizationConfig struct {
	PoolingStrategy string
}
|
modules/text2vec-transformers/ent/vectorization_result.go
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package ent
|
13 |
+
|
14 |
+
// VectorizationResult is the outcome of a vectorization call.
type VectorizationResult struct {
	Text       string
	Dimensions int
	Vector     []float32
}
|
modules/text2vec-transformers/module.go
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package modtransformers
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"net/http"
|
17 |
+
"os"
|
18 |
+
"time"
|
19 |
+
|
20 |
+
"github.com/pkg/errors"
|
21 |
+
"github.com/sirupsen/logrus"
|
22 |
+
"github.com/weaviate/weaviate/entities/models"
|
23 |
+
"github.com/weaviate/weaviate/entities/modulecapabilities"
|
24 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
25 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/clients"
|
26 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
|
27 |
+
"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
|
28 |
+
)
|
29 |
+
|
30 |
+
// New returns an empty TransformersModule; its dependencies are set up later
// by Init and InitExtension.
func New() *TransformersModule {
	return &TransformersModule{}
}
|
33 |
+
|
34 |
+
// TransformersModule is the text2vec-transformers module.
type TransformersModule struct {
	// vectorizer and metaProvider are set by initVectorizer.
	vectorizer   textVectorizer
	metaProvider metaProvider
	// graphqlProvider and searcher are set by initNearText.
	graphqlProvider modulecapabilities.GraphQLArguments
	searcher        modulecapabilities.Searcher
	// nearTextTransformer is taken from another module in InitExtension, if
	// one provides it.
	nearTextTransformer          modulecapabilities.TextTransform
	logger                       logrus.FieldLogger
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
|
43 |
+
|
44 |
+
// textVectorizer is what the module needs from its vectorizer: vectorizing a
// whole object and vectorizing a list of texts.
type textVectorizer interface {
	Object(ctx context.Context, obj *models.Object, objDiff *moduletools.ObjectDiff,
		cfg moduletools.ClassConfig) error
	Texts(ctx context.Context, input []string,
		cfg moduletools.ClassConfig) ([]float32, error)
}
|
50 |
+
|
51 |
+
// metaProvider supplies the data returned by the module's MetaInfo.
type metaProvider interface {
	MetaInfo() (map[string]interface{}, error)
}
|
54 |
+
|
55 |
+
// Name returns the module name, "text2vec-transformers".
func (m *TransformersModule) Name() string {
	return "text2vec-transformers"
}
|
58 |
+
|
59 |
+
// Type reports the module as a Text2Vec module.
func (m *TransformersModule) Type() modulecapabilities.ModuleType {
	return modulecapabilities.Text2Vec
}
|
62 |
+
|
63 |
+
func (m *TransformersModule) Init(ctx context.Context,
|
64 |
+
params moduletools.ModuleInitParams,
|
65 |
+
) error {
|
66 |
+
m.logger = params.GetLogger()
|
67 |
+
|
68 |
+
if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
|
69 |
+
return errors.Wrap(err, "init vectorizer")
|
70 |
+
}
|
71 |
+
|
72 |
+
if err := m.initAdditionalPropertiesProvider(); err != nil {
|
73 |
+
return errors.Wrap(err, "init additional properties provider")
|
74 |
+
}
|
75 |
+
|
76 |
+
return nil
|
77 |
+
}
|
78 |
+
|
79 |
+
func (m *TransformersModule) InitExtension(modules []modulecapabilities.Module) error {
|
80 |
+
for _, module := range modules {
|
81 |
+
if module.Name() == m.Name() {
|
82 |
+
continue
|
83 |
+
}
|
84 |
+
if arg, ok := module.(modulecapabilities.TextTransformers); ok {
|
85 |
+
if arg != nil && arg.TextTransformers() != nil {
|
86 |
+
m.nearTextTransformer = arg.TextTransformers()["nearText"]
|
87 |
+
}
|
88 |
+
}
|
89 |
+
}
|
90 |
+
|
91 |
+
if err := m.initNearText(); err != nil {
|
92 |
+
return errors.Wrap(err, "init graphql provider")
|
93 |
+
}
|
94 |
+
return nil
|
95 |
+
}
|
96 |
+
|
97 |
+
func (m *TransformersModule) initVectorizer(ctx context.Context, timeout time.Duration,
|
98 |
+
logger logrus.FieldLogger,
|
99 |
+
) error {
|
100 |
+
// TODO: gh-1486 proper config management
|
101 |
+
uriPassage := os.Getenv("TRANSFORMERS_PASSAGE_INFERENCE_API")
|
102 |
+
uriQuery := os.Getenv("TRANSFORMERS_QUERY_INFERENCE_API")
|
103 |
+
uriCommon := os.Getenv("TRANSFORMERS_INFERENCE_API")
|
104 |
+
|
105 |
+
if uriCommon == "" {
|
106 |
+
if uriPassage == "" && uriQuery == "" {
|
107 |
+
return errors.Errorf("required variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API are not set")
|
108 |
+
}
|
109 |
+
if uriPassage != "" && uriQuery == "" {
|
110 |
+
return errors.Errorf("required variable TRANSFORMERS_QUERY_INFERENCE_API is not set")
|
111 |
+
}
|
112 |
+
if uriPassage == "" && uriQuery != "" {
|
113 |
+
return errors.Errorf("required variable TRANSFORMERS_PASSAGE_INFERENCE_API is not set")
|
114 |
+
}
|
115 |
+
} else {
|
116 |
+
if uriPassage != "" || uriQuery != "" {
|
117 |
+
return errors.Errorf("either variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API should be set")
|
118 |
+
}
|
119 |
+
uriPassage = uriCommon
|
120 |
+
uriQuery = uriCommon
|
121 |
+
}
|
122 |
+
|
123 |
+
client := clients.New(uriPassage, uriQuery, timeout, logger)
|
124 |
+
if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
|
125 |
+
return errors.Wrap(err, "init remote vectorizer")
|
126 |
+
}
|
127 |
+
|
128 |
+
m.vectorizer = vectorizer.New(client)
|
129 |
+
m.metaProvider = client
|
130 |
+
|
131 |
+
return nil
|
132 |
+
}
|
133 |
+
|
134 |
+
// initAdditionalPropertiesProvider installs the Text2Vec provider for
// additional properties. It always returns nil.
func (m *TransformersModule) initAdditionalPropertiesProvider() error {
	m.additionalPropertiesProvider = additional.NewText2VecProvider()
	return nil
}
|
138 |
+
|
139 |
+
// RootHandler returns nil: this module provides no HTTP handler.
func (m *TransformersModule) RootHandler() http.Handler {
	// TODO: remove once this is a capability interface
	return nil
}
|
143 |
+
|
144 |
+
// VectorizeObject delegates to the vectorizer's Object method.
func (m *TransformersModule) VectorizeObject(ctx context.Context,
	obj *models.Object, objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
) error {
	return m.vectorizer.Object(ctx, obj, objDiff, cfg)
}
|
149 |
+
|
150 |
+
// MetaInfo returns the meta information reported by the meta provider.
func (m *TransformersModule) MetaInfo() (map[string]interface{}, error) {
	return m.metaProvider.MetaInfo()
}
|
153 |
+
|
154 |
+
// AdditionalProperties returns the additional properties offered by the
// provider installed in initAdditionalPropertiesProvider.
func (m *TransformersModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	return m.additionalPropertiesProvider.AdditionalProperties()
}
|
157 |
+
|
158 |
+
// VectorizeInput vectorizes a single text by passing it as a one-element
// slice to the vectorizer's Texts method.
func (m *TransformersModule) VectorizeInput(ctx context.Context,
	input string, cfg moduletools.ClassConfig,
) ([]float32, error) {
	return m.vectorizer.Texts(ctx, []string{input}, cfg)
}
|
163 |
+
|
164 |
+
// verify we implement the modules.Module interface as well as the Vectorizer
// and MetaProvider capabilities (compile-time checks)
var (
	_ = modulecapabilities.Module(New())
	_ = modulecapabilities.Vectorizer(New())
	_ = modulecapabilities.MetaProvider(New())
)
|
modules/text2vec-transformers/nearText.go
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: [email protected]
|
10 |
+
//
|
11 |
+
|
12 |
+
package modtransformers
|
13 |
+
|
14 |
+
import (
|
15 |
+
"github.com/weaviate/weaviate/entities/modulecapabilities"
|
16 |
+
"github.com/weaviate/weaviate/usecases/modulecomponents/nearText"
|
17 |
+
)
|
18 |
+
|
19 |
+
func (m *TransformersModule) initNearText() error {
|
20 |
+
m.searcher = nearText.NewSearcher(m.vectorizer)
|
21 |
+
m.graphqlProvider = nearText.New(m.nearTextTransformer)
|
22 |
+
return nil
|
23 |
+
}
|
24 |
+
|
25 |
+
func (m *TransformersModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
|
26 |
+
return m.graphqlProvider.Arguments()
|
27 |
+
}
|
28 |
+
|
29 |
+
func (m *TransformersModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
|
30 |
+
return m.searcher.VectorSearches()
|
31 |
+
}
|
32 |
+
|
33 |
+
var (
|
34 |
+
_ = modulecapabilities.GraphQLArguments(New())
|
35 |
+
_ = modulecapabilities.Searcher(New())
|
36 |
+
)
|
modules/text2vec-transformers/vectorizer/class_settings.go
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
16 |
+
)
|
17 |
+
|
18 |
+
// Fallback values used by classSettings when no config is present, the
// requested key is missing, or the stored value has an unexpected type.
const (
	// DefaultPropertyIndexed is the fallback result of PropertyIndexed.
	DefaultPropertyIndexed = true
	// DefaultVectorizeClassName is the fallback result of VectorizeClassName.
	DefaultVectorizeClassName = true
	// DefaultVectorizePropertyName is the fallback result of
	// VectorizePropertyName.
	DefaultVectorizePropertyName = false
	// DefaultPoolingStrategy is the fallback result of PoolingStrategy.
	DefaultPoolingStrategy = "masked_mean"
)
|
24 |
+
|
25 |
+
type classSettings struct {
|
26 |
+
cfg moduletools.ClassConfig
|
27 |
+
}
|
28 |
+
|
29 |
+
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
|
30 |
+
return &classSettings{cfg: cfg}
|
31 |
+
}
|
32 |
+
|
33 |
+
func (ic *classSettings) PropertyIndexed(propName string) bool {
|
34 |
+
if ic.cfg == nil {
|
35 |
+
// we would receive a nil-config on cross-class requests, such as Explore{}
|
36 |
+
return DefaultPropertyIndexed
|
37 |
+
}
|
38 |
+
|
39 |
+
vcn, ok := ic.cfg.Property(propName)["skip"]
|
40 |
+
if !ok {
|
41 |
+
return DefaultPropertyIndexed
|
42 |
+
}
|
43 |
+
|
44 |
+
asBool, ok := vcn.(bool)
|
45 |
+
if !ok {
|
46 |
+
return DefaultPropertyIndexed
|
47 |
+
}
|
48 |
+
|
49 |
+
return !asBool
|
50 |
+
}
|
51 |
+
|
52 |
+
func (ic *classSettings) VectorizePropertyName(propName string) bool {
|
53 |
+
if ic.cfg == nil {
|
54 |
+
// we would receive a nil-config on cross-class requests, such as Explore{}
|
55 |
+
return DefaultVectorizePropertyName
|
56 |
+
}
|
57 |
+
vcn, ok := ic.cfg.Property(propName)["vectorizePropertyName"]
|
58 |
+
if !ok {
|
59 |
+
return DefaultVectorizePropertyName
|
60 |
+
}
|
61 |
+
|
62 |
+
asBool, ok := vcn.(bool)
|
63 |
+
if !ok {
|
64 |
+
return DefaultVectorizePropertyName
|
65 |
+
}
|
66 |
+
|
67 |
+
return asBool
|
68 |
+
}
|
69 |
+
|
70 |
+
func (ic *classSettings) VectorizeClassName() bool {
|
71 |
+
if ic.cfg == nil {
|
72 |
+
// we would receive a nil-config on cross-class requests, such as Explore{}
|
73 |
+
return DefaultVectorizeClassName
|
74 |
+
}
|
75 |
+
|
76 |
+
vcn, ok := ic.cfg.Class()["vectorizeClassName"]
|
77 |
+
if !ok {
|
78 |
+
return DefaultVectorizeClassName
|
79 |
+
}
|
80 |
+
|
81 |
+
asBool, ok := vcn.(bool)
|
82 |
+
if !ok {
|
83 |
+
return DefaultVectorizeClassName
|
84 |
+
}
|
85 |
+
|
86 |
+
return asBool
|
87 |
+
}
|
88 |
+
|
89 |
+
func (ic *classSettings) PoolingStrategy() string {
|
90 |
+
if ic.cfg == nil {
|
91 |
+
// we would receive a nil-config on cross-class requests, such as Explore{}
|
92 |
+
return DefaultPoolingStrategy
|
93 |
+
}
|
94 |
+
|
95 |
+
vcn, ok := ic.cfg.Class()["poolingStrategy"]
|
96 |
+
if !ok {
|
97 |
+
return DefaultPoolingStrategy
|
98 |
+
}
|
99 |
+
|
100 |
+
asString, ok := vcn.(string)
|
101 |
+
if !ok {
|
102 |
+
return DefaultPoolingStrategy
|
103 |
+
}
|
104 |
+
|
105 |
+
return asString
|
106 |
+
}
|
modules/text2vec-transformers/vectorizer/class_settings_test.go
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"testing"
|
16 |
+
|
17 |
+
"github.com/stretchr/testify/assert"
|
18 |
+
"github.com/weaviate/weaviate/entities/models"
|
19 |
+
"github.com/weaviate/weaviate/usecases/modules"
|
20 |
+
)
|
21 |
+
|
22 |
+
func TestClassSettings(t *testing.T) {
|
23 |
+
t.Run("with all defaults", func(t *testing.T) {
|
24 |
+
class := &models.Class{
|
25 |
+
Class: "MyClass",
|
26 |
+
Properties: []*models.Property{{
|
27 |
+
Name: "someProp",
|
28 |
+
}},
|
29 |
+
}
|
30 |
+
|
31 |
+
cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
|
32 |
+
ic := NewClassSettings(cfg)
|
33 |
+
|
34 |
+
assert.True(t, ic.PropertyIndexed("someProp"))
|
35 |
+
assert.False(t, ic.VectorizePropertyName("someProp"))
|
36 |
+
assert.True(t, ic.VectorizeClassName())
|
37 |
+
assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
|
38 |
+
})
|
39 |
+
|
40 |
+
t.Run("with a nil config", func(t *testing.T) {
|
41 |
+
// this is the case if we were running in a situation such as a
|
42 |
+
// cross-class vectorization of search time, as is the case with Explore
|
43 |
+
// {}, we then expect all default values
|
44 |
+
|
45 |
+
ic := NewClassSettings(nil)
|
46 |
+
|
47 |
+
assert.True(t, ic.PropertyIndexed("someProp"))
|
48 |
+
assert.False(t, ic.VectorizePropertyName("someProp"))
|
49 |
+
assert.True(t, ic.VectorizeClassName())
|
50 |
+
assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
|
51 |
+
})
|
52 |
+
|
53 |
+
t.Run("with all explicit config matching the defaults", func(t *testing.T) {
|
54 |
+
class := &models.Class{
|
55 |
+
Class: "MyClass",
|
56 |
+
ModuleConfig: map[string]interface{}{
|
57 |
+
"my-module": map[string]interface{}{
|
58 |
+
"vectorizeClassName": true,
|
59 |
+
"poolingStrategy": "masked_mean",
|
60 |
+
},
|
61 |
+
},
|
62 |
+
Properties: []*models.Property{{
|
63 |
+
Name: "someProp",
|
64 |
+
ModuleConfig: map[string]interface{}{
|
65 |
+
"my-module": map[string]interface{}{
|
66 |
+
"skip": false,
|
67 |
+
"vectorizePropertyName": false,
|
68 |
+
},
|
69 |
+
},
|
70 |
+
}},
|
71 |
+
}
|
72 |
+
|
73 |
+
cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
|
74 |
+
ic := NewClassSettings(cfg)
|
75 |
+
|
76 |
+
assert.True(t, ic.PropertyIndexed("someProp"))
|
77 |
+
assert.False(t, ic.VectorizePropertyName("someProp"))
|
78 |
+
assert.True(t, ic.VectorizeClassName())
|
79 |
+
assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
|
80 |
+
})
|
81 |
+
|
82 |
+
t.Run("with all explicit config using non-default values", func(t *testing.T) {
|
83 |
+
class := &models.Class{
|
84 |
+
Class: "MyClass",
|
85 |
+
ModuleConfig: map[string]interface{}{
|
86 |
+
"my-module": map[string]interface{}{
|
87 |
+
"vectorizeClassName": false,
|
88 |
+
"poolingStrategy": "cls",
|
89 |
+
},
|
90 |
+
},
|
91 |
+
Properties: []*models.Property{{
|
92 |
+
Name: "someProp",
|
93 |
+
ModuleConfig: map[string]interface{}{
|
94 |
+
"my-module": map[string]interface{}{
|
95 |
+
"skip": true,
|
96 |
+
"vectorizePropertyName": true,
|
97 |
+
},
|
98 |
+
},
|
99 |
+
}},
|
100 |
+
}
|
101 |
+
|
102 |
+
cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant")
|
103 |
+
ic := NewClassSettings(cfg)
|
104 |
+
|
105 |
+
assert.False(t, ic.PropertyIndexed("someProp"))
|
106 |
+
assert.True(t, ic.VectorizePropertyName("someProp"))
|
107 |
+
assert.False(t, ic.VectorizeClassName())
|
108 |
+
assert.Equal(t, ic.PoolingStrategy(), "cls")
|
109 |
+
})
|
110 |
+
}
|
modules/text2vec-transformers/vectorizer/fakes_for_test.go
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
|
17 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
|
18 |
+
)
|
19 |
+
|
20 |
+
type fakeClient struct {
|
21 |
+
lastInput string
|
22 |
+
lastConfig ent.VectorizationConfig
|
23 |
+
}
|
24 |
+
|
25 |
+
func (c *fakeClient) VectorizeObject(ctx context.Context,
|
26 |
+
text string, cfg ent.VectorizationConfig,
|
27 |
+
) (*ent.VectorizationResult, error) {
|
28 |
+
c.lastInput = text
|
29 |
+
c.lastConfig = cfg
|
30 |
+
return &ent.VectorizationResult{
|
31 |
+
Vector: []float32{0, 1, 2, 3},
|
32 |
+
Dimensions: 4,
|
33 |
+
Text: text,
|
34 |
+
}, nil
|
35 |
+
}
|
36 |
+
|
37 |
+
func (c *fakeClient) VectorizeQuery(ctx context.Context,
|
38 |
+
text string, cfg ent.VectorizationConfig,
|
39 |
+
) (*ent.VectorizationResult, error) {
|
40 |
+
return c.VectorizeObject(ctx, text, cfg)
|
41 |
+
}
|
42 |
+
|
43 |
+
// fakeClassConfig is a test double for a class config. Its fields decide
// what Class, ClassByModuleName and Property report.
type fakeClassConfig struct {
	// classConfig is returned verbatim by ClassByModuleName.
	classConfig map[string]interface{}
	// vectorizeClassName and poolingStrategy are reported by Class.
	vectorizeClassName bool
	// vectorizePropertyName makes Property report
	// "vectorizePropertyName": true for properties not matched below.
	vectorizePropertyName bool
	// skippedProperty names the property reported with "skip": true.
	skippedProperty string
	// excludedProperty names the property reported with
	// "vectorizePropertyName": false.
	excludedProperty string
	poolingStrategy  string
}
|
51 |
+
|
52 |
+
func (f fakeClassConfig) Class() map[string]interface{} {
|
53 |
+
classSettings := map[string]interface{}{
|
54 |
+
"vectorizeClassName": f.vectorizeClassName,
|
55 |
+
"poolingStrategy": f.poolingStrategy,
|
56 |
+
}
|
57 |
+
return classSettings
|
58 |
+
}
|
59 |
+
|
60 |
+
func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
|
61 |
+
return f.classConfig
|
62 |
+
}
|
63 |
+
|
64 |
+
func (f fakeClassConfig) Property(propName string) map[string]interface{} {
|
65 |
+
if propName == f.skippedProperty {
|
66 |
+
return map[string]interface{}{
|
67 |
+
"skip": true,
|
68 |
+
}
|
69 |
+
}
|
70 |
+
if propName == f.excludedProperty {
|
71 |
+
return map[string]interface{}{
|
72 |
+
"vectorizePropertyName": false,
|
73 |
+
}
|
74 |
+
}
|
75 |
+
if f.vectorizePropertyName {
|
76 |
+
return map[string]interface{}{
|
77 |
+
"vectorizePropertyName": true,
|
78 |
+
}
|
79 |
+
}
|
80 |
+
return nil
|
81 |
+
}
|
82 |
+
|
83 |
+
func (f fakeClassConfig) Tenant() string {
|
84 |
+
return ""
|
85 |
+
}
|
modules/text2vec-transformers/vectorizer/objects.go
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
|
17 |
+
"github.com/weaviate/weaviate/entities/models"
|
18 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
19 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
|
20 |
+
objectsvectorizer "github.com/weaviate/weaviate/usecases/modulecomponents/vectorizer"
|
21 |
+
)
|
22 |
+
|
23 |
+
type Vectorizer struct {
|
24 |
+
client Client
|
25 |
+
objectVectorizer *objectsvectorizer.ObjectVectorizer
|
26 |
+
}
|
27 |
+
|
28 |
+
func New(client Client) *Vectorizer {
|
29 |
+
return &Vectorizer{
|
30 |
+
client: client,
|
31 |
+
objectVectorizer: objectsvectorizer.New(),
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
type Client interface {
|
36 |
+
VectorizeObject(ctx context.Context, input string,
|
37 |
+
cfg ent.VectorizationConfig) (*ent.VectorizationResult, error)
|
38 |
+
VectorizeQuery(ctx context.Context, input string,
|
39 |
+
cfg ent.VectorizationConfig) (*ent.VectorizationResult, error)
|
40 |
+
}
|
41 |
+
|
42 |
+
// ClassSettings exposes the per-class and per-property settings that
// influence vectorization.
type ClassSettings interface {
	// PropertyIndexed returns whether a property of a class should be
	// indexed.
	PropertyIndexed(property string) bool
	// VectorizeClassName returns the class-level vectorizeClassName setting.
	VectorizeClassName() bool
	// VectorizePropertyName returns the vectorizePropertyName setting of the
	// given property.
	VectorizePropertyName(propertyName string) bool
	// PoolingStrategy returns the class-level pooling strategy.
	PoolingStrategy() string
}
|
49 |
+
|
50 |
+
func (v *Vectorizer) Object(ctx context.Context, object *models.Object,
|
51 |
+
objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
|
52 |
+
) error {
|
53 |
+
vec, err := v.object(ctx, object.Class, object.Properties, objDiff, cfg)
|
54 |
+
if err != nil {
|
55 |
+
return err
|
56 |
+
}
|
57 |
+
|
58 |
+
object.Vector = vec
|
59 |
+
return nil
|
60 |
+
}
|
61 |
+
|
62 |
+
func (v *Vectorizer) object(ctx context.Context, className string,
|
63 |
+
schema interface{}, objDiff *moduletools.ObjectDiff, cfg moduletools.ClassConfig,
|
64 |
+
) ([]float32, error) {
|
65 |
+
text, vector, err := v.objectVectorizer.TextsOrVector(ctx, className, schema, objDiff, NewClassSettings(cfg))
|
66 |
+
if err != nil {
|
67 |
+
return nil, err
|
68 |
+
}
|
69 |
+
if vector != nil {
|
70 |
+
// dont' re-vectorize
|
71 |
+
return vector, nil
|
72 |
+
}
|
73 |
+
// vectorize text
|
74 |
+
res, err := v.client.VectorizeObject(ctx, text, ent.VectorizationConfig{
|
75 |
+
PoolingStrategy: NewClassSettings(cfg).PoolingStrategy(),
|
76 |
+
})
|
77 |
+
if err != nil {
|
78 |
+
return nil, err
|
79 |
+
}
|
80 |
+
|
81 |
+
return res.Vector, nil
|
82 |
+
}
|
modules/text2vec-transformers/vectorizer/objects_test.go
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"strings"
|
17 |
+
"testing"
|
18 |
+
|
19 |
+
"github.com/stretchr/testify/assert"
|
20 |
+
"github.com/stretchr/testify/require"
|
21 |
+
"github.com/weaviate/weaviate/entities/models"
|
22 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
23 |
+
)
|
24 |
+
|
25 |
+
// These are mostly copy/pasted (with minimal additions) from the
|
26 |
+
// text2vec-contextionary module
|
27 |
+
func TestVectorizingObjects(t *testing.T) {
|
28 |
+
type testCase struct {
|
29 |
+
name string
|
30 |
+
input *models.Object
|
31 |
+
expectedClientCall string
|
32 |
+
expectedPoolingStrategy string
|
33 |
+
noindex string
|
34 |
+
excludedProperty string // to simulate a schema where property names aren't vectorized
|
35 |
+
excludedClass string // to simulate a schema where class names aren't vectorized
|
36 |
+
poolingStrategy string
|
37 |
+
}
|
38 |
+
|
39 |
+
tests := []testCase{
|
40 |
+
{
|
41 |
+
name: "empty object",
|
42 |
+
input: &models.Object{
|
43 |
+
Class: "Car",
|
44 |
+
},
|
45 |
+
poolingStrategy: "cls",
|
46 |
+
expectedPoolingStrategy: "cls",
|
47 |
+
expectedClientCall: "car",
|
48 |
+
},
|
49 |
+
{
|
50 |
+
name: "object with one string prop",
|
51 |
+
input: &models.Object{
|
52 |
+
Class: "Car",
|
53 |
+
Properties: map[string]interface{}{
|
54 |
+
"brand": "Mercedes",
|
55 |
+
},
|
56 |
+
},
|
57 |
+
expectedClientCall: "car brand mercedes",
|
58 |
+
},
|
59 |
+
|
60 |
+
{
|
61 |
+
name: "object with one non-string prop",
|
62 |
+
input: &models.Object{
|
63 |
+
Class: "Car",
|
64 |
+
Properties: map[string]interface{}{
|
65 |
+
"power": 300,
|
66 |
+
},
|
67 |
+
},
|
68 |
+
expectedClientCall: "car",
|
69 |
+
},
|
70 |
+
|
71 |
+
{
|
72 |
+
name: "object with a mix of props",
|
73 |
+
input: &models.Object{
|
74 |
+
Class: "Car",
|
75 |
+
Properties: map[string]interface{}{
|
76 |
+
"brand": "best brand",
|
77 |
+
"power": 300,
|
78 |
+
"review": "a very great car",
|
79 |
+
},
|
80 |
+
},
|
81 |
+
expectedClientCall: "car brand best brand review a very great car",
|
82 |
+
},
|
83 |
+
{
|
84 |
+
name: "with a noindexed property",
|
85 |
+
noindex: "review",
|
86 |
+
input: &models.Object{
|
87 |
+
Class: "Car",
|
88 |
+
Properties: map[string]interface{}{
|
89 |
+
"brand": "best brand",
|
90 |
+
"power": 300,
|
91 |
+
"review": "a very great car",
|
92 |
+
},
|
93 |
+
},
|
94 |
+
expectedClientCall: "car brand best brand",
|
95 |
+
},
|
96 |
+
|
97 |
+
{
|
98 |
+
name: "with the class name not vectorized",
|
99 |
+
excludedClass: "Car",
|
100 |
+
input: &models.Object{
|
101 |
+
Class: "Car",
|
102 |
+
Properties: map[string]interface{}{
|
103 |
+
"brand": "best brand",
|
104 |
+
"power": 300,
|
105 |
+
"review": "a very great car",
|
106 |
+
},
|
107 |
+
},
|
108 |
+
expectedClientCall: "brand best brand review a very great car",
|
109 |
+
},
|
110 |
+
|
111 |
+
{
|
112 |
+
name: "with a property name not vectorized",
|
113 |
+
excludedProperty: "review",
|
114 |
+
input: &models.Object{
|
115 |
+
Class: "Car",
|
116 |
+
Properties: map[string]interface{}{
|
117 |
+
"brand": "best brand",
|
118 |
+
"power": 300,
|
119 |
+
"review": "a very great car",
|
120 |
+
},
|
121 |
+
},
|
122 |
+
expectedClientCall: "car brand best brand a very great car",
|
123 |
+
},
|
124 |
+
|
125 |
+
{
|
126 |
+
name: "with no schema labels vectorized",
|
127 |
+
excludedProperty: "review",
|
128 |
+
excludedClass: "Car",
|
129 |
+
input: &models.Object{
|
130 |
+
Class: "Car",
|
131 |
+
Properties: map[string]interface{}{
|
132 |
+
"review": "a very great car",
|
133 |
+
},
|
134 |
+
},
|
135 |
+
expectedClientCall: "a very great car",
|
136 |
+
},
|
137 |
+
|
138 |
+
{
|
139 |
+
name: "with string/text arrays without propname or classname",
|
140 |
+
excludedProperty: "reviews",
|
141 |
+
excludedClass: "Car",
|
142 |
+
input: &models.Object{
|
143 |
+
Class: "Car",
|
144 |
+
Properties: map[string]interface{}{
|
145 |
+
"reviews": []interface{}{
|
146 |
+
"a very great car",
|
147 |
+
"you should consider buying one",
|
148 |
+
},
|
149 |
+
},
|
150 |
+
},
|
151 |
+
expectedClientCall: "a very great car you should consider buying one",
|
152 |
+
},
|
153 |
+
|
154 |
+
{
|
155 |
+
name: "with string/text arrays with propname and classname",
|
156 |
+
input: &models.Object{
|
157 |
+
Class: "Car",
|
158 |
+
Properties: map[string]interface{}{
|
159 |
+
"reviews": []interface{}{
|
160 |
+
"a very great car",
|
161 |
+
"you should consider buying one",
|
162 |
+
},
|
163 |
+
},
|
164 |
+
},
|
165 |
+
expectedClientCall: "car reviews a very great car reviews you should consider buying one",
|
166 |
+
},
|
167 |
+
|
168 |
+
{
|
169 |
+
name: "with compound class and prop names",
|
170 |
+
input: &models.Object{
|
171 |
+
Class: "SuperCar",
|
172 |
+
Properties: map[string]interface{}{
|
173 |
+
"brandOfTheCar": "best brand",
|
174 |
+
"power": 300,
|
175 |
+
"review": "a very great car",
|
176 |
+
},
|
177 |
+
},
|
178 |
+
expectedClientCall: "super car brand of the car best brand review a very great car",
|
179 |
+
},
|
180 |
+
}
|
181 |
+
|
182 |
+
for _, test := range tests {
|
183 |
+
t.Run(test.name, func(t *testing.T) {
|
184 |
+
client := &fakeClient{}
|
185 |
+
|
186 |
+
v := New(client)
|
187 |
+
|
188 |
+
ic := &fakeClassConfig{
|
189 |
+
excludedProperty: test.excludedProperty,
|
190 |
+
skippedProperty: test.noindex,
|
191 |
+
vectorizeClassName: test.excludedClass != "Car",
|
192 |
+
poolingStrategy: test.poolingStrategy,
|
193 |
+
vectorizePropertyName: true,
|
194 |
+
}
|
195 |
+
err := v.Object(context.Background(), test.input, nil, ic)
|
196 |
+
|
197 |
+
require.Nil(t, err)
|
198 |
+
assert.Equal(t, models.C11yVector{0, 1, 2, 3}, test.input.Vector)
|
199 |
+
expected := strings.Split(test.expectedClientCall, " ")
|
200 |
+
actual := strings.Split(client.lastInput, " ")
|
201 |
+
assert.Equal(t, expected, actual)
|
202 |
+
assert.Equal(t, client.lastConfig.PoolingStrategy, test.expectedPoolingStrategy)
|
203 |
+
})
|
204 |
+
}
|
205 |
+
}
|
206 |
+
|
207 |
+
func TestVectorizingObjectsWithDiff(t *testing.T) {
|
208 |
+
type testCase struct {
|
209 |
+
name string
|
210 |
+
input *models.Object
|
211 |
+
skipped string
|
212 |
+
diff *moduletools.ObjectDiff
|
213 |
+
expectedVectorize bool
|
214 |
+
}
|
215 |
+
|
216 |
+
tests := []testCase{
|
217 |
+
{
|
218 |
+
name: "no diff",
|
219 |
+
input: &models.Object{
|
220 |
+
Class: "Car",
|
221 |
+
Properties: map[string]interface{}{
|
222 |
+
"brand": "best brand",
|
223 |
+
"power": 300,
|
224 |
+
"description": "a very great car",
|
225 |
+
"reviews": []interface{}{
|
226 |
+
"a very great car",
|
227 |
+
"you should consider buying one",
|
228 |
+
},
|
229 |
+
},
|
230 |
+
},
|
231 |
+
diff: nil,
|
232 |
+
expectedVectorize: true,
|
233 |
+
},
|
234 |
+
{
|
235 |
+
name: "diff all props unchanged",
|
236 |
+
input: &models.Object{
|
237 |
+
Class: "Car",
|
238 |
+
Properties: map[string]interface{}{
|
239 |
+
"brand": "best brand",
|
240 |
+
"power": 300,
|
241 |
+
"description": "a very great car",
|
242 |
+
"reviews": []interface{}{
|
243 |
+
"a very great car",
|
244 |
+
"you should consider buying one",
|
245 |
+
},
|
246 |
+
},
|
247 |
+
},
|
248 |
+
diff: newObjectDiffWithVector().
|
249 |
+
WithProp("brand", "best brand", "best brand").
|
250 |
+
WithProp("power", 300, 300).
|
251 |
+
WithProp("description", "a very great car", "a very great car").
|
252 |
+
WithProp("reviews", []interface{}{
|
253 |
+
"a very great car",
|
254 |
+
"you should consider buying one",
|
255 |
+
}, []interface{}{
|
256 |
+
"a very great car",
|
257 |
+
"you should consider buying one",
|
258 |
+
}),
|
259 |
+
expectedVectorize: false,
|
260 |
+
},
|
261 |
+
{
|
262 |
+
name: "diff one vectorizable prop changed (1)",
|
263 |
+
input: &models.Object{
|
264 |
+
Class: "Car",
|
265 |
+
Properties: map[string]interface{}{
|
266 |
+
"brand": "best brand",
|
267 |
+
"power": 300,
|
268 |
+
"description": "a very great car",
|
269 |
+
"reviews": []interface{}{
|
270 |
+
"a very great car",
|
271 |
+
"you should consider buying one",
|
272 |
+
},
|
273 |
+
},
|
274 |
+
},
|
275 |
+
diff: newObjectDiffWithVector().
|
276 |
+
WithProp("brand", "old best brand", "best brand"),
|
277 |
+
expectedVectorize: true,
|
278 |
+
},
|
279 |
+
{
|
280 |
+
name: "diff one vectorizable prop changed (2)",
|
281 |
+
input: &models.Object{
|
282 |
+
Class: "Car",
|
283 |
+
Properties: map[string]interface{}{
|
284 |
+
"brand": "best brand",
|
285 |
+
"power": 300,
|
286 |
+
"description": "a very great car",
|
287 |
+
"reviews": []interface{}{
|
288 |
+
"a very great car",
|
289 |
+
"you should consider buying one",
|
290 |
+
},
|
291 |
+
},
|
292 |
+
},
|
293 |
+
diff: newObjectDiffWithVector().
|
294 |
+
WithProp("description", "old a very great car", "a very great car"),
|
295 |
+
expectedVectorize: true,
|
296 |
+
},
|
297 |
+
{
|
298 |
+
name: "diff one vectorizable prop changed (3)",
|
299 |
+
input: &models.Object{
|
300 |
+
Class: "Car",
|
301 |
+
Properties: map[string]interface{}{
|
302 |
+
"brand": "best brand",
|
303 |
+
"power": 300,
|
304 |
+
"description": "a very great car",
|
305 |
+
"reviews": []interface{}{
|
306 |
+
"a very great car",
|
307 |
+
"you should consider buying one",
|
308 |
+
},
|
309 |
+
},
|
310 |
+
},
|
311 |
+
diff: newObjectDiffWithVector().
|
312 |
+
WithProp("reviews", []interface{}{
|
313 |
+
"old a very great car",
|
314 |
+
"you should consider buying one",
|
315 |
+
}, []interface{}{
|
316 |
+
"a very great car",
|
317 |
+
"you should consider buying one",
|
318 |
+
}),
|
319 |
+
expectedVectorize: true,
|
320 |
+
},
|
321 |
+
{
|
322 |
+
name: "all non-vectorizable props changed",
|
323 |
+
skipped: "description",
|
324 |
+
input: &models.Object{
|
325 |
+
Class: "Car",
|
326 |
+
Properties: map[string]interface{}{
|
327 |
+
"brand": "best brand",
|
328 |
+
"power": 300,
|
329 |
+
"description": "a very great car",
|
330 |
+
"reviews": []interface{}{
|
331 |
+
"a very great car",
|
332 |
+
"you should consider buying one",
|
333 |
+
},
|
334 |
+
},
|
335 |
+
},
|
336 |
+
diff: newObjectDiffWithVector().
|
337 |
+
WithProp("power", 123, 300).
|
338 |
+
WithProp("description", "old a very great car", "a very great car"),
|
339 |
+
expectedVectorize: false,
|
340 |
+
},
|
341 |
+
}
|
342 |
+
|
343 |
+
for _, test := range tests {
|
344 |
+
t.Run(test.name, func(t *testing.T) {
|
345 |
+
ic := &fakeClassConfig{
|
346 |
+
skippedProperty: test.skipped,
|
347 |
+
}
|
348 |
+
|
349 |
+
client := &fakeClient{}
|
350 |
+
v := New(client)
|
351 |
+
|
352 |
+
err := v.Object(context.Background(), test.input, test.diff, ic)
|
353 |
+
|
354 |
+
require.Nil(t, err)
|
355 |
+
if test.expectedVectorize {
|
356 |
+
assert.Equal(t, models.C11yVector{0, 1, 2, 3}, test.input.Vector)
|
357 |
+
assert.NotEmpty(t, client.lastInput)
|
358 |
+
} else {
|
359 |
+
assert.Equal(t, models.C11yVector{0, 0, 0, 0}, test.input.Vector)
|
360 |
+
assert.Empty(t, client.lastInput)
|
361 |
+
}
|
362 |
+
})
|
363 |
+
}
|
364 |
+
}
|
365 |
+
|
366 |
+
func newObjectDiffWithVector() *moduletools.ObjectDiff {
|
367 |
+
return moduletools.NewObjectDiff([]float32{0, 0, 0, 0})
|
368 |
+
}
|
modules/text2vec-transformers/vectorizer/texts.go
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
|
17 |
+
"github.com/pkg/errors"
|
18 |
+
"github.com/weaviate/weaviate/entities/moduletools"
|
19 |
+
"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
|
20 |
+
libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
|
21 |
+
)
|
22 |
+
|
23 |
+
func (v *Vectorizer) Texts(ctx context.Context, inputs []string,
|
24 |
+
cfg moduletools.ClassConfig,
|
25 |
+
) ([]float32, error) {
|
26 |
+
vectors := make([][]float32, len(inputs))
|
27 |
+
for i := range inputs {
|
28 |
+
res, err := v.client.VectorizeQuery(ctx, inputs[i], ent.VectorizationConfig{
|
29 |
+
PoolingStrategy: NewClassSettings(cfg).PoolingStrategy(),
|
30 |
+
})
|
31 |
+
if err != nil {
|
32 |
+
return nil, errors.Wrap(err, "remote client vectorize")
|
33 |
+
}
|
34 |
+
vectors[i] = res.Vector
|
35 |
+
}
|
36 |
+
|
37 |
+
return libvectorizer.CombineVectors(vectors), nil
|
38 |
+
}
|
modules/text2vec-transformers/vectorizer/texts_test.go
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// _ _
|
2 |
+
// __ _____ __ ___ ___ __ _| |_ ___
|
3 |
+
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
|
4 |
+
// \ V V / __/ (_| |\ V /| | (_| | || __/
|
5 |
+
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
|
6 |
+
//
|
7 |
+
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
|
8 |
+
//
|
9 |
+
// CONTACT: hello@weaviate.io
|
10 |
+
//
|
11 |
+
|
12 |
+
package vectorizer
|
13 |
+
|
14 |
+
import (
|
15 |
+
"context"
|
16 |
+
"testing"
|
17 |
+
|
18 |
+
"github.com/stretchr/testify/assert"
|
19 |
+
"github.com/stretchr/testify/require"
|
20 |
+
)
|
21 |
+
|
22 |
+
// as used in the nearText searcher
|
23 |
+
func TestVectorizingTexts(t *testing.T) {
|
24 |
+
type testCase struct {
|
25 |
+
name string
|
26 |
+
input []string
|
27 |
+
expectedPoolingStrategy string
|
28 |
+
poolingStrategy string
|
29 |
+
}
|
30 |
+
|
31 |
+
tests := []testCase{
|
32 |
+
{
|
33 |
+
name: "single word",
|
34 |
+
input: []string{"hello"},
|
35 |
+
poolingStrategy: "cls",
|
36 |
+
expectedPoolingStrategy: "cls",
|
37 |
+
},
|
38 |
+
{
|
39 |
+
name: "multiple words",
|
40 |
+
input: []string{"hello world, this is me!"},
|
41 |
+
poolingStrategy: "cls",
|
42 |
+
expectedPoolingStrategy: "cls",
|
43 |
+
},
|
44 |
+
|
45 |
+
{
|
46 |
+
name: "multiple sentences (joined with a dot)",
|
47 |
+
input: []string{"this is sentence 1", "and here's number 2"},
|
48 |
+
poolingStrategy: "cls",
|
49 |
+
expectedPoolingStrategy: "cls",
|
50 |
+
},
|
51 |
+
|
52 |
+
{
|
53 |
+
name: "multiple sentences already containing a dot",
|
54 |
+
input: []string{"this is sentence 1.", "and here's number 2"},
|
55 |
+
poolingStrategy: "cls",
|
56 |
+
expectedPoolingStrategy: "cls",
|
57 |
+
},
|
58 |
+
{
|
59 |
+
name: "multiple sentences already containing a question mark",
|
60 |
+
input: []string{"this is sentence 1?", "and here's number 2"},
|
61 |
+
poolingStrategy: "cls",
|
62 |
+
expectedPoolingStrategy: "cls",
|
63 |
+
},
|
64 |
+
{
|
65 |
+
name: "multiple sentences already containing an exclamation mark",
|
66 |
+
input: []string{"this is sentence 1!", "and here's number 2"},
|
67 |
+
poolingStrategy: "cls",
|
68 |
+
expectedPoolingStrategy: "cls",
|
69 |
+
},
|
70 |
+
{
|
71 |
+
name: "multiple sentences already containing comma",
|
72 |
+
input: []string{"this is sentence 1,", "and here's number 2"},
|
73 |
+
poolingStrategy: "cls",
|
74 |
+
expectedPoolingStrategy: "cls",
|
75 |
+
},
|
76 |
+
}
|
77 |
+
|
78 |
+
for _, test := range tests {
|
79 |
+
t.Run(test.name, func(t *testing.T) {
|
80 |
+
client := &fakeClient{}
|
81 |
+
|
82 |
+
v := New(client)
|
83 |
+
|
84 |
+
settings := &fakeClassConfig{
|
85 |
+
poolingStrategy: test.poolingStrategy,
|
86 |
+
}
|
87 |
+
vec, err := v.Texts(context.Background(), test.input, settings)
|
88 |
+
|
89 |
+
require.Nil(t, err)
|
90 |
+
assert.Equal(t, []float32{0, 1, 2, 3}, vec)
|
91 |
+
assert.Equal(t, client.lastConfig.PoolingStrategy, test.expectedPoolingStrategy)
|
92 |
+
})
|
93 |
+
}
|
94 |
+
}
|