Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
//  CONTACT: hello@weaviate.io
// | |
package clients | |
import ( | |
"context" | |
"encoding/json" | |
"io" | |
"net/http" | |
"net/http/httptest" | |
"os" | |
"strings" | |
"testing" | |
"time" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/modules/text2vec-aws/ent" | |
) | |
func TestClient(t *testing.T) { | |
t.Run("when all is fine", func(t *testing.T) { | |
t.Skip("Skipping this test for now") | |
server := httptest.NewServer(&fakeHandler{t: t}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "access_key", | |
awsSecret: "secret", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
expected := &ent.VectorizationResult{ | |
Text: "This is my text", | |
Vector: []float32{0.1, 0.2, 0.3}, | |
Dimensions: 3, | |
} | |
res, err := c.Vectorize(context.Background(), []string{"This is my text"}, | |
ent.VectorizationConfig{ | |
Service: "bedrock", | |
Region: "region", | |
Model: "model", | |
}) | |
assert.Nil(t, err) | |
assert.Equal(t, expected, res) | |
}) | |
t.Run("when all is fine - Sagemaker", func(t *testing.T) { | |
server := httptest.NewServer(&fakeHandler{t: t}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "access_key", | |
awsSecret: "secret", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
expected := &ent.VectorizationResult{ | |
Text: "This is my text", | |
Vector: []float32{0.1, 0.2, 0.3}, | |
Dimensions: 3, | |
} | |
res, err := c.Vectorize(context.Background(), []string{"This is my text"}, | |
ent.VectorizationConfig{ | |
Service: "sagemaker", | |
Region: "region", | |
Endpoint: "endpoint", | |
}) | |
assert.Nil(t, err) | |
assert.Equal(t, expected, res) | |
}) | |
t.Run("when the server returns an error", func(t *testing.T) { | |
t.Skip("Skipping this test for now") | |
server := httptest.NewServer(&fakeHandler{ | |
t: t, | |
serverError: errors.Errorf("nope, not gonna happen"), | |
}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "access_key", | |
awsSecret: "secret", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
_, err := c.Vectorize(context.Background(), []string{"This is my text"}, | |
ent.VectorizationConfig{ | |
Service: "bedrock", | |
}) | |
require.NotNil(t, err) | |
assert.EqualError(t, err, "connection to AWS failed with status: 500 error: nope, not gonna happen") | |
}) | |
t.Run("when AWS key is passed using X-Aws-Api-Key header", func(t *testing.T) { | |
t.Skip("Skipping this test for now") | |
server := httptest.NewServer(&fakeHandler{t: t}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "access_key", | |
awsSecret: "secret", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
ctxWithValue := context.WithValue(context.Background(), | |
"X-Aws-Api-Key", []string{"some-key"}) | |
expected := &ent.VectorizationResult{ | |
Text: "This is my text", | |
Vector: []float32{0.1, 0.2, 0.3}, | |
Dimensions: 3, | |
} | |
res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ | |
Service: "bedrock", | |
}) | |
require.Nil(t, err) | |
assert.Equal(t, expected, res) | |
}) | |
t.Run("when X-Aws-Access-Key header is passed but empty", func(t *testing.T) { | |
t.Skip("Skipping this test for now") | |
server := httptest.NewServer(&fakeHandler{t: t}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "", | |
awsSecret: "123", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
ctxWithValue := context.WithValue(context.Background(), | |
"X-Aws-Api-Key", []string{""}) | |
_, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ | |
Service: "bedrock", | |
}) | |
require.NotNil(t, err) | |
assert.Equal(t, err.Error(), "AWS Access Key: no access key found neither in request header: "+ | |
"X-Aws-Access-Key nor in environment variable under AWS_ACCESS_KEY_ID") | |
}) | |
t.Run("when X-Aws-Secret-Key header is passed but empty", func(t *testing.T) { | |
t.Skip("Skipping this test for now") | |
server := httptest.NewServer(&fakeHandler{t: t}) | |
defer server.Close() | |
c := &aws{ | |
httpClient: &http.Client{}, | |
logger: nullLogger(), | |
awsAccessKey: "123", | |
awsSecret: "", | |
buildBedrockUrlFn: func(service, region, model string) string { | |
return server.URL | |
}, | |
buildSagemakerUrlFn: func(service, region, endpoint string) string { | |
return server.URL | |
}, | |
} | |
ctxWithValue := context.WithValue(context.Background(), | |
"X-Aws-Api-Key", []string{""}) | |
_, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ | |
Service: "bedrock", | |
}) | |
require.NotNil(t, err) | |
assert.Equal(t, err.Error(), "AWS Secret Key: no secret found neither in request header: "+ | |
"X-Aws-Access-Secret nor in environment variable under AWS_SECRET_ACCESS_KEY") | |
}) | |
} | |
func TestBuildBedrockUrl(t *testing.T) { | |
service := "bedrock" | |
region := "us-east-1" | |
t.Run("when using a Cohere", func(t *testing.T) { | |
model := "cohere.embed-english-v3" | |
expected := "https://bedrock-runtime.us-east-1.amazonaws.com/model/cohere.embed-english-v3/invoke" | |
result := buildBedrockUrl(service, region, model) | |
if result != expected { | |
t.Errorf("Expected %s but got %s", expected, result) | |
} | |
}) | |
t.Run("When using an AWS model", func(t *testing.T) { | |
model := "amazon.titan-e1t-medium" | |
expected := "https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-e1t-medium/invoke" | |
result := buildBedrockUrl(service, region, model) | |
if result != expected { | |
t.Errorf("Expected %s but got %s", expected, result) | |
} | |
}) | |
} | |
func TestCreateRequestBody(t *testing.T) { | |
input := []string{"Hello, world!"} | |
t.Run("Create request for Amazon embedding model", func(t *testing.T) { | |
model := "amazon.titan-e1t-medium" | |
req, _ := createRequestBody(model, input, vectorizeObject) | |
_, ok := req.(bedrockEmbeddingsRequest) | |
if !ok { | |
t.Fatalf("Expected req to be a bedrockEmbeddingsRequest, got %T", req) | |
} | |
}) | |
t.Run("Create request for Cohere embedding model", func(t *testing.T) { | |
model := "cohere.embed-english-v3" | |
req, _ := createRequestBody(model, input, vectorizeObject) | |
_, ok := req.(bedrockCohereEmbeddingRequest) | |
if !ok { | |
t.Fatalf("Expected req to be a bedrockCohereEmbeddingRequest, got %T", req) | |
} | |
}) | |
t.Run("Create request for unknown embedding model", func(t *testing.T) { | |
model := "unknown.model" | |
_, err := createRequestBody(model, input, vectorizeObject) | |
if err == nil { | |
t.Errorf("Expected an error for unknown model, got nil") | |
} | |
}) | |
} | |
func TestVectorize(t *testing.T) { | |
ctx := context.Background() | |
input := []string{"Hello, world!"} | |
t.Run("Vectorize using an Amazon model", func(t *testing.T) { | |
t.Skip("Skipping because CI doesnt have the right credentials") | |
config := ent.VectorizationConfig{ | |
Model: "amazon.titan-e1t-medium", | |
Service: "bedrock", | |
Region: "us-east-1", | |
} | |
awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_AMAZON") | |
awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_AMAZON") | |
aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) | |
_, err := aws.Vectorize(ctx, input, config) | |
if err != nil { | |
t.Errorf("Vectorize returned an error: %v", err) | |
} | |
}) | |
t.Run("Vectorize using a Cohere model", func(t *testing.T) { | |
t.Skip("Skipping because CI doesnt have the right credentials") | |
config := ent.VectorizationConfig{ | |
Model: "cohere.embed-english-v3", | |
Service: "bedrock", | |
Region: "us-east-1", | |
} | |
awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_COHERE") | |
awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_COHERE") | |
aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) | |
_, err := aws.Vectorize(ctx, input, config) | |
if err != nil { | |
t.Errorf("Vectorize returned an error: %v", err) | |
} | |
}) | |
} | |
func TestExtractHostAndPath(t *testing.T) { | |
t.Run("valid URL", func(t *testing.T) { | |
endpointUrl := "https://service.region.amazonaws.com/model/model-name/invoke" | |
expectedHost := "service.region.amazonaws.com" | |
expectedPath := "/model/model-name/invoke" | |
host, path, err := extractHostAndPath(endpointUrl) | |
if err != nil { | |
t.Errorf("Unexpected error: %v", err) | |
} | |
if host != expectedHost { | |
t.Errorf("Expected host %s but got %s", expectedHost, host) | |
} | |
if path != expectedPath { | |
t.Errorf("Expected path %s but got %s", expectedPath, path) | |
} | |
}) | |
t.Run("URL without host or path", func(t *testing.T) { | |
endpointUrl := "https://" | |
_, _, err := extractHostAndPath(endpointUrl) | |
if err == nil { | |
t.Error("Expected error but got nil") | |
} | |
}) | |
} | |
// fakeHandler is the http.Handler the tests in this file mount on an
// httptest server in place of the real AWS embedding endpoints.
type fakeHandler struct {
	// t is the test that assertions made while serving a request report to.
	t *testing.T
	// serverError, when non-nil, makes the handler reply with HTTP 500 and
	// this error's text as the response message instead of an embedding.
	serverError error
}
func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
assert.Equal(f.t, http.MethodPost, r.Method) | |
authHeader := r.Header["Authorization"][0] | |
if f.serverError != nil { | |
var outBytes []byte | |
var err error | |
if strings.Contains(authHeader, "bedrock") { | |
embeddingResponse := &bedrockEmbeddingResponse{ | |
Message: ptString(f.serverError.Error()), | |
} | |
outBytes, err = json.Marshal(embeddingResponse) | |
} else { | |
embeddingResponse := &sagemakerEmbeddingResponse{ | |
Message: ptString(f.serverError.Error()), | |
} | |
outBytes, err = json.Marshal(embeddingResponse) | |
} | |
require.Nil(f.t, err) | |
w.WriteHeader(http.StatusInternalServerError) | |
w.Write(outBytes) | |
return | |
} | |
bodyBytes, err := io.ReadAll(r.Body) | |
require.Nil(f.t, err) | |
defer r.Body.Close() | |
var outBytes []byte | |
if strings.Contains(authHeader, "bedrock") { | |
var req bedrockEmbeddingsRequest | |
require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) | |
textInput := req.InputText | |
assert.Greater(f.t, len(textInput), 0) | |
embeddingResponse := &bedrockEmbeddingResponse{ | |
Embedding: []float32{0.1, 0.2, 0.3}, | |
} | |
outBytes, err = json.Marshal(embeddingResponse) | |
} else { | |
var req sagemakerEmbeddingsRequest | |
require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) | |
textInputs := req.TextInputs | |
assert.Greater(f.t, len(textInputs), 0) | |
embeddingResponse := &sagemakerEmbeddingResponse{ | |
Embedding: [][]float32{{0.1, 0.2, 0.3}}, | |
} | |
outBytes, err = json.Marshal(embeddingResponse) | |
} | |
require.Nil(f.t, err) | |
w.Write(outBytes) | |
} | |
func nullLogger() logrus.FieldLogger { | |
l, _ := test.NewNullLogger() | |
return l | |
} | |
// ptString returns a pointer to a string holding the value of in, which is
// handy for populating optional *string fields from a literal.
func ptString(in string) *string {
	value := in
	return &value
}