KevinStephenson
Adding in weaviate code
b110593
raw
history blame
12 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/modules/text2vec-aws/ent"
)
// operationType names the kind of vectorization being requested.
// createRequestBody compares it against vectorizeQuery to pick the Cohere
// "input_type" value.
type operationType string

// Supported operation types. They are fixed identifiers, so they are declared
// as typed constants rather than as reassignable package-level variables.
const (
	// vectorizeObject is the operation passed by Vectorize.
	vectorizeObject operationType = "vectorize_object"
	// vectorizeQuery is the operation passed by VectorizeQuery.
	vectorizeQuery operationType = "vectorize_query"
)
// buildBedrockUrl returns the HTTPS invoke URL for model in region.
//
// The host is "<service>.<region>.amazonaws.com"; when the model identifier
// starts with "cohere" the service part becomes "<service>-runtime".
func buildBedrockUrl(service, region, model string) string {
	hostPrefix := service
	if strings.HasPrefix(model, "cohere") {
		hostPrefix += "-runtime"
	}
	return fmt.Sprintf("https://%s.%s.amazonaws.com/model/%s/invoke", hostPrefix, region, model)
}
// buildSagemakerUrl returns the HTTPS invocation URL of a SageMaker endpoint:
// "https://runtime.<service>.<region>.amazonaws.com/endpoints/<endpoint>/invocations".
func buildSagemakerUrl(service, region, endpoint string) string {
	host := fmt.Sprintf("runtime.%s.%s.amazonaws.com", service, region)
	return fmt.Sprintf("https://%s/endpoints/%s/invocations", host, endpoint)
}
// aws is the client that sends embedding requests to AWS Bedrock or SageMaker.
type aws struct {
	// awsAccessKey is the fallback access key used by getAwsAccessKey when the
	// request context does not carry one.
	awsAccessKey string
	// awsSecret is the fallback secret used by getAwsAccessSecret when the
	// request context does not carry one.
	awsSecret string
	// buildBedrockUrlFn and buildSagemakerUrlFn produce the endpoint URL that
	// vectorize posts to. New sets them to buildBedrockUrl and buildSagemakerUrl.
	// NOTE(review): presumably kept as fields so they can be swapped out (e.g.
	// in tests) — confirm.
	buildBedrockUrlFn func(service, region, model string) string
	buildSagemakerUrlFn func(service, region, endpoint string) string
	// httpClient performs the HTTP calls in makeRequest.
	httpClient *http.Client
	// logger receives the debug messages emitted when a request is retried.
	logger logrus.FieldLogger
}
// New returns an aws client that falls back to the given access key and secret,
// uses an HTTP client with the given timeout, builds endpoint URLs with
// buildBedrockUrl / buildSagemakerUrl and logs through logger.
func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger logrus.FieldLogger) *aws {
	client := &aws{
		buildBedrockUrlFn:   buildBedrockUrl,
		buildSagemakerUrlFn: buildSagemakerUrl,
		logger:              logger,
	}
	client.awsAccessKey = awsAccessKey
	client.awsSecret = awsSecret
	client.httpClient = &http.Client{Timeout: timeout}
	return client
}
// Vectorize embeds input using the vectorizeObject operation and the AWS
// service described by config.
func (v *aws) Vectorize(ctx context.Context, input []string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
	result, err := v.vectorize(ctx, input, vectorizeObject, config)
	return result, err
}
// VectorizeQuery embeds input using the vectorizeQuery operation and the AWS
// service described by config.
func (v *aws) VectorizeQuery(ctx context.Context, input []string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
	result, err := v.vectorize(ctx, input, vectorizeQuery, config)
	return result, err
}
// vectorize sends input to the AWS service named in config and returns the
// embedding parsed from the response.
//
// For "bedrock" the request body comes from createRequestBody and the signing
// host/path are taken from the built endpoint URL. For "sagemaker" the body is
// a sagemakerEmbeddingsRequest and optional target-model / target-variant
// headers are added. Any other service is rejected with an error.
//
// Credentials are resolved through getAwsAccessKey / getAwsAccessSecret and the
// request is signed with getAuthHeader before it is posted.
//
// It returns an error when input is empty, when the service is unsupported,
// when credentials are missing, or when building, sending, reading or parsing
// the request/response fails.
func (v *aws) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
	// Both the request builder and the response parsers index input[0]; reject
	// empty input up front instead of panicking further down.
	if len(input) == 0 {
		return nil, errors.New("empty input")
	}

	service := v.getService(config)
	region := v.getRegion(config)

	headers := map[string]string{
		"accept":       "*/*",
		"content-type": contentType,
	}

	var (
		body        []byte
		endpointUrl string
		host        string
		path        string
	)

	switch {
	case v.isBedrock(service):
		model := v.getModel(config)
		endpointUrl = v.buildBedrockUrlFn(service, region, model)
		// The error is not surfaced: if the URL cannot be split, host and path
		// stay empty and the request is signed with those empty values.
		// NOTE(review): confirm this best-effort behaviour is intended (e.g. for
		// injected URL builders) before turning it into a hard failure.
		host, path, _ = extractHostAndPath(endpointUrl)

		payload, err := createRequestBody(model, input, operation)
		if err != nil {
			return nil, err
		}
		body, err = json.Marshal(payload)
		if err != nil {
			return nil, errors.Wrapf(err, "marshal body")
		}
	case v.isSagemaker(service):
		endpoint := v.getEndpoint(config)
		endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint)
		host = "runtime." + service + "." + region + ".amazonaws.com"
		path = "/endpoints/" + endpoint + "/invocations"
		if targetModel := v.getTargetModel(config); targetModel != "" {
			headers["x-amzn-sagemaker-target-model"] = targetModel
		}
		if targetVariant := v.getTargetVariant(config); targetVariant != "" {
			headers["x-amzn-sagemaker-target-variant"] = targetVariant
		}

		encoded, err := json.Marshal(sagemakerEmbeddingsRequest{
			TextInputs: input,
		})
		if err != nil {
			return nil, errors.Wrapf(err, "marshal body")
		}
		body = encoded
	default:
		// Wrapping a nil error yields nil, which made this branch return
		// (nil, nil); build a real error instead.
		return nil, errors.Errorf("service error: unsupported service %q", service)
	}

	accessKey, err := v.getAwsAccessKey(ctx)
	if err != nil {
		return nil, errors.Wrapf(err, "AWS Access Key")
	}
	secretKey, err := v.getAwsAccessSecret(ctx)
	if err != nil {
		return nil, errors.Wrapf(err, "AWS Secret Key")
	}

	headers["host"] = host
	amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers)
	headers["Authorization"] = authorizationHeader
	headers["x-amz-date"] = amzDate

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	// 30 is the initial retry delay in seconds, 5 the maximum number of attempts.
	res, err := v.makeRequest(req, 30, 5)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	if v.isBedrock(service) {
		return v.parseBedrockResponse(bodyBytes, res, input)
	}
	return v.parseSagemakerResponse(bodyBytes, res, input)
}
// makeRequest sends req and retries while the response status is 429 (Too Many
// Requests) or 400 (Bad Request).
//
// delayInSeconds is the wait before the first retry and is doubled after every
// retry. maxRetries is the maximum number of attempts (including the first
// one) and must be at least 1.
//
// It returns the first response whose status is not retryable, or the last
// response once the attempts are exhausted; the caller must close its body.
// An error is returned when maxRetries is below 1, when sending fails, when the
// request body cannot be rewound for a retry, or when the request context is
// done while waiting for the next attempt.
func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) {
	// Without at least one attempt there is no response to return; fail loudly
	// instead of handing back (nil, nil).
	if maxRetries < 1 {
		return nil, errors.Errorf("send POST request: maxRetries must be at least 1, got %d", maxRetries)
	}

	// Generate a UUID so that all attempts of this request share one ID in the logs.
	requestID := uuid.New().String()

	for attempt := 1; ; attempt++ {
		// The previous attempt consumed the body; get a fresh reader before
		// sending the same request again.
		if attempt > 1 && req.GetBody != nil {
			body, err := req.GetBody()
			if err != nil {
				return nil, errors.Wrap(err, "rewind request body")
			}
			req.Body = body
		}

		res, err := v.httpClient.Do(req)
		if err != nil {
			return nil, errors.Wrap(err, "send POST request")
		}

		retryable := res.StatusCode == http.StatusTooManyRequests || res.StatusCode == http.StatusBadRequest
		if !retryable || attempt == maxRetries {
			// Either a final answer, or no attempts left: return without
			// sleeping so the caller is not delayed for nothing.
			return res, nil
		}

		// This response is being discarded: drain and close it so the
		// connection can be reused and is not leaked.
		_, _ = io.Copy(io.Discard, res.Body)
		res.Body.Close()

		v.logger.Debugf("Request ID %s to %s returned %d, retrying in %d seconds", requestID, req.URL, res.StatusCode, delayInSeconds)

		// Wait for the back-off delay, but give up early if the request
		// context is cancelled or times out.
		timer := time.NewTimer(time.Duration(delayInSeconds) * time.Second)
		select {
		case <-req.Context().Done():
			timer.Stop()
			return nil, errors.Wrap(req.Context().Err(), "wait before retry")
		case <-timer.C:
		}

		// Double the delay for the next attempt.
		delayInSeconds *= 2
	}
}
// parseBedrockResponse turns a Bedrock response into a vectorization result.
//
// bodyBytes is the already-read response body, res supplies the HTTP status
// code, and input is the text that was sent; input[0] becomes the result text.
//
// The body is decoded once into bedrockEmbeddingResponse, which covers both
// response shapes handled here: a single "embedding" vector, or a list of
// "embeddings" of which only the first one is used.
//
// It returns an error when the body is not valid JSON for that struct, when the
// response carries a message or a non-200 status, when no vector is present, or
// when input is empty.
func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) {
	var resBody bedrockEmbeddingResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	// A message in the body is reported together with the status code, even
	// when the status itself is 200.
	if resBody.Message != nil {
		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s",
			res.StatusCode, *resBody.Message)
	}
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode)
	}

	if len(resBody.Embedding) == 0 && len(resBody.Embeddings) == 0 {
		return nil, fmt.Errorf("could not obtain vector from AWS Bedrock")
	}
	// input[0] is read below; guard against an empty slice instead of panicking.
	if len(input) == 0 {
		return nil, fmt.Errorf("could not obtain vector from AWS Bedrock: empty input")
	}

	// Prefer the first entry of the list form when it is present.
	embedding := resBody.Embedding
	if len(resBody.Embeddings) > 0 {
		embedding = resBody.Embeddings[0]
	}

	return &ent.VectorizationResult{
		Text:       input[0],
		Dimensions: len(embedding),
		Vector:     embedding,
	}, nil
}
// parseSagemakerResponse turns a SageMaker response into a vectorization
// result.
//
// bodyBytes is the already-read response body, res supplies the HTTP status
// code, and input is the text that was sent; input[0] becomes the result text.
// Only the first vector of the "embedding" list is returned.
//
// It returns an error when the body cannot be decoded, when the response
// carries a message or a non-200 status, when the embedding list is empty, or
// when input is empty.
func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) {
	var resBody sagemakerEmbeddingResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	// A message in the body is reported together with the status code, even
	// when the status itself is 200.
	if resBody.Message != nil {
		return nil, fmt.Errorf("connection to AWS failed with status: %v error: %s",
			res.StatusCode, *resBody.Message)
	}
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("connection to AWS failed with status: %d", res.StatusCode)
	}

	if len(resBody.Embedding) == 0 {
		return nil, errors.Errorf("empty embeddings response")
	}
	// input[0] is read below; guard against an empty slice instead of panicking.
	if len(input) == 0 {
		return nil, errors.Errorf("empty input")
	}

	first := resBody.Embedding[0]
	return &ent.VectorizationResult{
		Text:       input[0],
		Dimensions: len(first),
		Vector:     first,
	}, nil
}
// isSagemaker reports whether service is exactly "sagemaker".
func (v *aws) isSagemaker(service string) bool {
	const sagemakerService = "sagemaker"
	return service == sagemakerService
}
// isBedrock reports whether service is exactly "bedrock".
func (v *aws) isBedrock(service string) bool {
	const bedrockService = "bedrock"
	return service == bedrockService
}
// getAwsAccessKey resolves the AWS access key for a request.
//
// A non-empty first element of the []string stored in ctx under
// "X-Aws-Access-Key" takes precedence; otherwise the key the client was
// created with is used. An error is returned when neither is available.
func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) {
	if fromHeader, ok := ctx.Value("X-Aws-Access-Key").([]string); ok {
		if len(fromHeader) > 0 && fromHeader[0] != "" {
			return fromHeader[0], nil
		}
	}
	if v.awsAccessKey != "" {
		return v.awsAccessKey, nil
	}
	return "", errors.New("no access key found " +
		"neither in request header: X-AWS-Access-Key " +
		"nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY")
}
// getAwsAccessSecret resolves the AWS secret key for a request.
//
// A non-empty first element of the []string stored in ctx under
// "X-Aws-Secret-Key" takes precedence; otherwise the secret the client was
// created with is used. An error is returned when neither is available.
func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) {
	if fromHeader, ok := ctx.Value("X-Aws-Secret-Key").([]string); ok {
		if len(fromHeader) > 0 && fromHeader[0] != "" {
			return fromHeader[0], nil
		}
	}
	if v.awsSecret != "" {
		return v.awsSecret, nil
	}
	return "", errors.New("no secret found " +
		"neither in request header: X-AWS-Secret-Key " +
		"nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY")
}
// getModel returns the Model field of config.
func (v *aws) getModel(config ent.VectorizationConfig) string {
	model := config.Model
	return model
}
// getRegion returns the Region field of config.
func (v *aws) getRegion(config ent.VectorizationConfig) string {
	region := config.Region
	return region
}
// getService returns the Service field of config.
func (v *aws) getService(config ent.VectorizationConfig) string {
	service := config.Service
	return service
}
// getEndpoint returns the Endpoint field of config.
func (v *aws) getEndpoint(config ent.VectorizationConfig) string {
	endpoint := config.Endpoint
	return endpoint
}
// getTargetModel returns the TargetModel field of config.
func (v *aws) getTargetModel(config ent.VectorizationConfig) string {
	targetModel := config.TargetModel
	return targetModel
}
// getTargetVariant returns the TargetVariant field of config.
func (v *aws) getTargetVariant(config ent.VectorizationConfig) string {
	targetVariant := config.TargetVariant
	return targetVariant
}
// bedrockEmbeddingsRequest is the request body createRequestBody builds for
// models whose provider prefix is "amazon".
type bedrockEmbeddingsRequest struct {
	// InputText is the single text to embed; createRequestBody fills it with
	// the first input text only.
	InputText string `json:"inputText,omitempty"`
}
// bedrockCohereEmbeddingRequest is the request body createRequestBody builds
// for models whose provider prefix is "cohere".
type bedrockCohereEmbeddingRequest struct {
	// Texts holds all input texts.
	Texts []string `json:"texts"`
	// InputType is "search_query" for the vectorizeQuery operation and
	// "search_document" otherwise.
	InputType string `json:"input_type"`
}
// sagemakerEmbeddingsRequest is the request body vectorize sends to a
// SageMaker endpoint.
type sagemakerEmbeddingsRequest struct {
	// TextInputs holds all input texts.
	TextInputs []string `json:"text_inputs,omitempty"`
}
// bedrockEmbeddingResponse is the shape parseBedrockResponse decodes a Bedrock
// response body into. It covers both the single-vector and the vector-list
// form of the response.
type bedrockEmbeddingResponse struct {
	// InputTextTokenCount is decoded but not read anywhere in this file.
	InputTextTokenCount int `json:"InputTextTokenCount,omitempty"`
	// Embedding is the single-vector form of the result.
	Embedding []float32 `json:"embedding,omitempty"`
	// Embeddings is the list form of the result; only its first entry is used.
	Embeddings [][]float32 `json:"embeddings,omitempty"`
	// Message, when present, makes parseBedrockResponse report a failure.
	Message *string `json:"message,omitempty"`
}
// sagemakerEmbeddingResponse is the shape parseSagemakerResponse decodes a
// SageMaker response body into.
type sagemakerEmbeddingResponse struct {
	// Embedding is the list of returned vectors; only its first entry is used.
	Embedding [][]float32 `json:"embedding,omitempty"`
	// ErrorCode, LogStreamArn, OriginalMessage and OriginalStatusCode are
	// decoded but not read anywhere in this file.
	// NOTE(review): presumably error details returned by the endpoint — confirm.
	ErrorCode *string `json:"ErrorCode,omitempty"`
	LogStreamArn *string `json:"LogStreamArn,omitempty"`
	OriginalMessage *string `json:"OriginalMessage,omitempty"`
	// Message, when present, makes parseSagemakerResponse report a failure.
	Message *string `json:"Message,omitempty"`
	OriginalStatusCode *int `json:"OriginalStatusCode,omitempty"`
}
// extractHostAndPath parses endpointUrl and returns its host and path.
//
// It returns the parse error when the URL is malformed, and an
// "invalid endpoint URL" error when either the host or the path is empty. In
// both error cases the returned host and path are empty strings.
func extractHostAndPath(endpointUrl string) (string, string, error) {
	parsed, err := url.Parse(endpointUrl)
	switch {
	case err != nil:
		return "", "", err
	case parsed.Host == "", parsed.Path == "":
		return "", "", fmt.Errorf("invalid endpoint URL: %s", endpointUrl)
	default:
		return parsed.Host, parsed.Path, nil
	}
}
// createRequestBody builds the Bedrock request payload for model.
//
// The provider is the part of the model identifier before the first ".":
//   - "amazon": a bedrockEmbeddingsRequest carrying only texts[0];
//   - "cohere": a bedrockCohereEmbeddingRequest carrying all texts, with
//     input type "search_query" for vectorizeQuery and "search_document" for
//     any other operation.
//
// It returns an error when texts is empty or the provider is not one of the
// two above.
func createRequestBody(model string, texts []string, operation operationType) (interface{}, error) {
	// texts[0] is read below; reject empty input instead of panicking.
	if len(texts) == 0 {
		return nil, fmt.Errorf("invalid input: no texts to vectorize")
	}

	// SplitN always yields at least one element, so indexing [0] is safe even
	// for an empty model string (which then falls through to the default case).
	modelProvider := strings.SplitN(model, ".", 2)[0]

	switch modelProvider {
	case "amazon":
		return bedrockEmbeddingsRequest{
			InputText: texts[0],
		}, nil
	case "cohere":
		inputType := "search_document"
		if operation == vectorizeQuery {
			inputType = "search_query"
		}
		return bedrockCohereEmbeddingRequest{
			Texts:     texts,
			InputType: inputType,
		}, nil
	default:
		return nil, fmt.Errorf("unknown model provider: %s", modelProvider)
	}
}