Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package clients | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"net/url" | |
"strings" | |
"time" | |
"github.com/google/uuid" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/modules/text2vec-aws/ent" | |
) | |
type operationType string | |
var ( | |
vectorizeObject operationType = "vectorize_object" | |
vectorizeQuery operationType = "vectorize_query" | |
) | |
func buildBedrockUrl(service, region, model string) string { | |
serviceName := service | |
if strings.HasPrefix(model, "cohere") { | |
serviceName = fmt.Sprintf("%s-runtime", serviceName) | |
} | |
urlTemplate := "https://%s.%s.amazonaws.com/model/%s/invoke" | |
return fmt.Sprintf(urlTemplate, serviceName, region, model) | |
} | |
func buildSagemakerUrl(service, region, endpoint string) string { | |
urlTemplate := "https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations" | |
return fmt.Sprintf(urlTemplate, service, region, endpoint) | |
} | |
type aws struct { | |
awsAccessKey string | |
awsSecret string | |
buildBedrockUrlFn func(service, region, model string) string | |
buildSagemakerUrlFn func(service, region, endpoint string) string | |
httpClient *http.Client | |
logger logrus.FieldLogger | |
} | |
func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger logrus.FieldLogger) *aws { | |
return &aws{ | |
awsAccessKey: awsAccessKey, | |
awsSecret: awsSecret, | |
httpClient: &http.Client{ | |
Timeout: timeout, | |
}, | |
buildBedrockUrlFn: buildBedrockUrl, | |
buildSagemakerUrlFn: buildSagemakerUrl, | |
logger: logger, | |
} | |
} | |
func (v *aws) Vectorize(ctx context.Context, input []string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, input, vectorizeObject, config) | |
} | |
func (v *aws) VectorizeQuery(ctx context.Context, input []string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, input, vectorizeQuery, config) | |
} | |
func (v *aws) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { | |
service := v.getService(config) | |
region := v.getRegion(config) | |
model := v.getModel(config) | |
endpoint := v.getEndpoint(config) | |
targetModel := v.getTargetModel(config) | |
targetVariant := v.getTargetVariant(config) | |
var body []byte | |
var endpointUrl string | |
var host string | |
var path string | |
var err error | |
headers := map[string]string{ | |
"accept": "*/*", | |
"content-type": contentType, | |
} | |
if v.isBedrock(service) { | |
endpointUrl = v.buildBedrockUrlFn(service, region, model) | |
host, path, _ = extractHostAndPath(endpointUrl) | |
req, err := createRequestBody(model, input, operation) | |
if err != nil { | |
return nil, err | |
} | |
body, err = json.Marshal(req) | |
if err != nil { | |
return nil, errors.Wrapf(err, "marshal body") | |
} | |
} else if v.isSagemaker(service) { | |
endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint) | |
host = "runtime." + service + "." + region + ".amazonaws.com" | |
path = "/endpoints/" + endpoint + "/invocations" | |
if targetModel != "" { | |
headers["x-amzn-sagemaker-target-model"] = targetModel | |
} | |
if targetVariant != "" { | |
headers["x-amzn-sagemaker-target-variant"] = targetVariant | |
} | |
body, err = json.Marshal(sagemakerEmbeddingsRequest{ | |
TextInputs: input, | |
}) | |
if err != nil { | |
return nil, errors.Wrapf(err, "marshal body") | |
} | |
} else { | |
return nil, errors.Wrapf(err, "service error") | |
} | |
accessKey, err := v.getAwsAccessKey(ctx) | |
if err != nil { | |
return nil, errors.Wrapf(err, "AWS Access Key") | |
} | |
secretKey, err := v.getAwsAccessSecret(ctx) | |
if err != nil { | |
return nil, errors.Wrapf(err, "AWS Secret Key") | |
} | |
headers["host"] = host | |
amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) | |
headers["Authorization"] = authorizationHeader | |
headers["x-amz-date"] = amzDate | |
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
for k, v := range headers { | |
req.Header.Set(k, v) | |
} | |
res, err := v.makeRequest(req, 30, 5) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
if v.isBedrock(service) { | |
return v.parseBedrockResponse(bodyBytes, res, input) | |
} else { | |
return v.parseSagemakerResponse(bodyBytes, res, input) | |
} | |
} | |
func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) { | |
var res *http.Response | |
var err error | |
// Generate a UUID for this request | |
requestID := uuid.New().String() | |
for i := 0; i < maxRetries; i++ { | |
res, err = v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
// If the status code is not 429 or 400, break the loop | |
if res.StatusCode != http.StatusTooManyRequests && res.StatusCode != http.StatusBadRequest { | |
break | |
} | |
v.logger.Debugf("Request ID %s to %s returned 429, retrying in %d seconds", requestID, req.URL, delayInSeconds) | |
// Sleep for a while and then continue to the next iteration | |
time.Sleep(time.Duration(delayInSeconds) * time.Second) | |
// Double the delay for the next iteration | |
delayInSeconds *= 2 | |
} | |
return res, err | |
} | |
func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { | |
var resBodyMap map[string]interface{} | |
if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
// if resBodyMap has inputTextTokenCount, it's a resonse from an Amazon model | |
// otherwise, it is a response from a Cohere model | |
var resBody bedrockEmbeddingResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if res.StatusCode != 200 || resBody.Message != nil { | |
if resBody.Message != nil { | |
return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s", | |
res.StatusCode, *resBody.Message) | |
} | |
return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode) | |
} | |
if len(resBody.Embedding) == 0 && len(resBody.Embeddings) == 0 { | |
return nil, fmt.Errorf("could not obtain vector from AWS Bedrock") | |
} | |
embedding := resBody.Embedding | |
if len(resBody.Embeddings) > 0 { | |
embedding = resBody.Embeddings[0] | |
} | |
return &ent.VectorizationResult{ | |
Text: input[0], | |
Dimensions: len(embedding), | |
Vector: embedding, | |
}, nil | |
} | |
func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { | |
var resBody sagemakerEmbeddingResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if res.StatusCode != 200 || resBody.Message != nil { | |
if resBody.Message != nil { | |
return nil, fmt.Errorf("connection to AWS failed with status: %v error: %s", | |
res.StatusCode, *resBody.Message) | |
} | |
return nil, fmt.Errorf("connection to AWS failed with status: %d", res.StatusCode) | |
} | |
if len(resBody.Embedding) == 0 { | |
return nil, errors.Errorf("empty embeddings response") | |
} | |
return &ent.VectorizationResult{ | |
Text: input[0], | |
Dimensions: len(resBody.Embedding[0]), | |
Vector: resBody.Embedding[0], | |
}, nil | |
} | |
func (v *aws) isSagemaker(service string) bool { | |
return service == "sagemaker" | |
} | |
func (v *aws) isBedrock(service string) bool { | |
return service == "bedrock" | |
} | |
func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { | |
awsAccessKey := ctx.Value("X-Aws-Access-Key") | |
if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok && | |
len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 { | |
return awsAccessKeyHeader[0], nil | |
} | |
if len(v.awsAccessKey) > 0 { | |
return v.awsAccessKey, nil | |
} | |
return "", errors.New("no access key found " + | |
"neither in request header: X-AWS-Access-Key " + | |
"nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY") | |
} | |
func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { | |
awsSecretKey := ctx.Value("X-Aws-Secret-Key") | |
if awsAccessSecretHeader, ok := awsSecretKey.([]string); ok && | |
len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 { | |
return awsAccessSecretHeader[0], nil | |
} | |
if len(v.awsSecret) > 0 { | |
return v.awsSecret, nil | |
} | |
return "", errors.New("no secret found " + | |
"neither in request header: X-AWS-Secret-Key " + | |
"nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY") | |
} | |
func (v *aws) getModel(config ent.VectorizationConfig) string { | |
return config.Model | |
} | |
func (v *aws) getRegion(config ent.VectorizationConfig) string { | |
return config.Region | |
} | |
func (v *aws) getService(config ent.VectorizationConfig) string { | |
return config.Service | |
} | |
func (v *aws) getEndpoint(config ent.VectorizationConfig) string { | |
return config.Endpoint | |
} | |
func (v *aws) getTargetModel(config ent.VectorizationConfig) string { | |
return config.TargetModel | |
} | |
func (v *aws) getTargetVariant(config ent.VectorizationConfig) string { | |
return config.TargetVariant | |
} | |
type bedrockEmbeddingsRequest struct { | |
InputText string `json:"inputText,omitempty"` | |
} | |
type bedrockCohereEmbeddingRequest struct { | |
Texts []string `json:"texts"` | |
InputType string `json:"input_type"` | |
} | |
type sagemakerEmbeddingsRequest struct { | |
TextInputs []string `json:"text_inputs,omitempty"` | |
} | |
type bedrockEmbeddingResponse struct { | |
InputTextTokenCount int `json:"InputTextTokenCount,omitempty"` | |
Embedding []float32 `json:"embedding,omitempty"` | |
Embeddings [][]float32 `json:"embeddings,omitempty"` | |
Message *string `json:"message,omitempty"` | |
} | |
type sagemakerEmbeddingResponse struct { | |
Embedding [][]float32 `json:"embedding,omitempty"` | |
ErrorCode *string `json:"ErrorCode,omitempty"` | |
LogStreamArn *string `json:"LogStreamArn,omitempty"` | |
OriginalMessage *string `json:"OriginalMessage,omitempty"` | |
Message *string `json:"Message,omitempty"` | |
OriginalStatusCode *int `json:"OriginalStatusCode,omitempty"` | |
} | |
func extractHostAndPath(endpointUrl string) (string, string, error) { | |
u, err := url.Parse(endpointUrl) | |
if err != nil { | |
return "", "", err | |
} | |
if u.Host == "" || u.Path == "" { | |
return "", "", fmt.Errorf("invalid endpoint URL: %s", endpointUrl) | |
} | |
return u.Host, u.Path, nil | |
} | |
func createRequestBody(model string, texts []string, operation operationType) (interface{}, error) { | |
modelParts := strings.Split(model, ".") | |
if len(modelParts) == 0 { | |
return nil, fmt.Errorf("invalid model: %s", model) | |
} | |
modelProvider := modelParts[0] | |
switch modelProvider { | |
case "amazon": | |
return bedrockEmbeddingsRequest{ | |
InputText: texts[0], | |
}, nil | |
case "cohere": | |
inputType := "search_document" | |
if operation == vectorizeQuery { | |
inputType = "search_query" | |
} | |
return bedrockCohereEmbeddingRequest{ | |
Texts: texts, | |
InputType: inputType, | |
}, nil | |
default: | |
return nil, fmt.Errorf("unknown model provider: %s", modelProvider) | |
} | |
} | |