// _ _ // __ _____ __ ___ ___ __ _| |_ ___ // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ // \ V V / __/ (_| |\ V /| | (_| | || __/ // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| // // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. // // CONTACT: hello@weaviate.io // package clients import ( "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaviate/weaviate/modules/text2vec-transformers/ent" ) func TestClient(t *testing.T) { t.Run("when all is fine", func(t *testing.T) { server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() c := New(server.URL, server.URL, 0, nullLogger()) expected := &ent.VectorizationResult{ Text: "This is my text", Vector: []float32{0.1, 0.2, 0.3}, Dimensions: 3, } res, err := c.VectorizeObject(context.Background(), "This is my text", ent.VectorizationConfig{ PoolingStrategy: "masked_mean", }) assert.Nil(t, err) assert.Equal(t, expected, res) }) t.Run("when the context is expired", func(t *testing.T) { server := httptest.NewServer(&fakeHandler{t: t}) defer server.Close() c := New(server.URL, server.URL, 0, nullLogger()) ctx, cancel := context.WithDeadline(context.Background(), time.Now()) defer cancel() _, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{}) require.NotNil(t, err) assert.Contains(t, err.Error(), "context deadline exceeded") }) t.Run("when the server returns an error", func(t *testing.T) { server := httptest.NewServer(&fakeHandler{ t: t, serverError: errors.Errorf("nope, not gonna happen"), }) defer server.Close() c := New(server.URL, server.URL, 0, nullLogger()) _, err := c.VectorizeObject(context.Background(), "This is my text", ent.VectorizationConfig{}) require.NotNil(t, err) assert.Contains(t, err.Error(), "nope, not gonna happen") }) } type fakeHandler struct { t *testing.T serverError error } func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { assert.Equal(f.t, "/vectors", r.URL.String()) assert.Equal(f.t, http.MethodPost, r.Method) if f.serverError != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error()))) return } bodyBytes, err := io.ReadAll(r.Body) require.Nil(f.t, err) defer r.Body.Close() var b map[string]interface{} require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) textInput := b["text"].(string) assert.Greater(f.t, len(textInput), 0) pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string) assert.Equal(f.t, "masked_mean", pooling) out := map[string]interface{}{ "text": textInput, "dims": 3, "vector": []float32{0.1, 0.2, 0.3}, } outBytes, err := json.Marshal(out) require.Nil(f.t, err) w.Write(outBytes) }