KevinStephenson
Adding in weaviate code
b110593
raw
history blame
14.4 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/modules/generative-aws/config"
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
)
// compile matches one `{property}` placeholder in a prompt template and
// captures the text between the braces (word characters and whitespace only).
//
// The pattern is a compile-time constant, so MustCompile is used instead of
// discarding the error from Compile: a broken pattern now fails loudly at
// package initialisation rather than leaving a nil *regexp.Regexp behind.
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
// buildBedrockUrl returns the HTTPS invoke URL for a Bedrock model, in the form
// https://<service>-runtime.<region>.amazonaws.com/model/<model>/invoke.
func buildBedrockUrl(service, region, model string) string {
	runtimeHost := fmt.Sprintf("%s-runtime", service)
	return fmt.Sprintf("https://%s.%s.amazonaws.com/model/%s/invoke", runtimeHost, region, model)
}
// buildSagemakerUrl returns the HTTPS invocation URL for a Sagemaker endpoint,
// in the form
// https://runtime.<service>.<region>.amazonaws.com/endpoints/<endpoint>/invocations.
func buildSagemakerUrl(service, region, endpoint string) string {
	return fmt.Sprintf(
		"https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations",
		service, region, endpoint,
	)
}
// aws is the client used by this module to send generative requests to AWS
// Bedrock or AWS Sagemaker.
type aws struct {
// awsAccessKey and awsSecretKey are the fallback credentials; they are only
// used when the request context does not carry a key (see getAwsAccessKey and
// getAwsAccessSecret).
awsAccessKey string
awsSecretKey string
// buildBedrockUrlFn and buildSagemakerUrlFn produce the request URL. New sets
// them to buildBedrockUrl and buildSagemakerUrl.
// NOTE(review): presumably kept as fields so tests can substitute them — confirm.
buildBedrockUrlFn func(service, region, model string) string
buildSagemakerUrlFn func(service, region, endpoint string) string
// httpClient performs the outgoing POST request in Generate.
httpClient *http.Client
// logger is stored by New; it is not referenced by the methods shown in this file.
logger logrus.FieldLogger
}
// New returns an aws client that falls back to the given access key and secret
// key when a request carries none, uses an HTTP client with the given timeout,
// and builds request URLs with buildBedrockUrl and buildSagemakerUrl.
func New(awsAccessKey string, awsSecretKey string, timeout time.Duration, logger logrus.FieldLogger) *aws {
	httpClient := &http.Client{Timeout: timeout}

	client := &aws{
		awsAccessKey:        awsAccessKey,
		awsSecretKey:        awsSecretKey,
		httpClient:          httpClient,
		buildBedrockUrlFn:   buildBedrockUrl,
		buildSagemakerUrlFn: buildSagemakerUrl,
		logger:              logger,
	}
	return client
}
// GenerateSingleResult fills the `{property}` placeholders of prompt with the
// values from textProperties and sends the resulting prompt through Generate.
// An error from the placeholder substitution is returned unchanged.
func (v *aws) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	filledPrompt, err := v.generateForPrompt(textProperties, prompt)
	if err != nil {
		return nil, err
	}

	return v.Generate(ctx, cfg, filledPrompt)
}
// GenerateAllResults combines task with the JSON encoding of all
// textProperties into one prompt and sends it through Generate. An error from
// building the prompt is returned unchanged.
func (v *aws) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	taskPrompt, err := v.generatePromptForTask(textProperties, task)
	if err != nil {
		return nil, err
	}

	return v.Generate(ctx, cfg, taskPrompt)
}
// Generate sends prompt to the AWS service selected by the class settings and
// returns the generated text.
//
// The service, region, model, endpoint and target model/variant are read from
// cfg. For the "bedrock" service the request body depends on the model vendor
// detected from the model name; for "sagemaker" the prompt is posted to the
// configured endpoint. The request is signed through getAuthHeader using the
// access key and secret resolved by getAwsAccessKey / getAwsAccessSecret, sent
// with the client's HTTP client, and the response body is handed to the
// matching parse method.
//
// An error is returned when the service is neither "bedrock" nor "sagemaker",
// when the body cannot be marshalled, when credentials are missing, when the
// HTTP round trip or reading the response fails, or when parsing the response
// reports an error.
func (v *aws) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
	settings := config.NewClassSettings(cfg)
	service := settings.Service()
	region := settings.Region()
	model := settings.Model()
	endpoint := settings.Endpoint()
	targetModel := settings.TargetModel()
	targetVariant := settings.TargetVariant()

	var (
		body        []byte
		endpointUrl string
		host        string
		path        string
		err         error
	)

	headers := map[string]string{
		"accept":       "*/*",
		"content-type": contentType,
	}

	switch {
	case v.isBedrock(service):
		endpointUrl = v.buildBedrockUrlFn(service, region, model)
		// host and path are passed separately to getAuthHeader below.
		host = service + "-runtime" + "." + region + ".amazonaws.com"
		path = "/model/" + model + "/invoke"

		// NOTE(review): a model name that matches none of the vendor checks
		// leaves body nil, so an empty payload is signed and sent.
		// NOTE(review): MaxTokenCount, Temperature and TopK are dereferenced
		// without a nil check; presumably the class settings always supply
		// defaults — confirm.
		switch {
		case v.isAmazonModel(model):
			body, err = json.Marshal(bedrockAmazonGenerateRequest{
				InputText: prompt,
			})
		case v.isAnthropicModel(model):
			// The prompt is wrapped in the Human/Assistant turn markers.
			var builder strings.Builder
			builder.WriteString("\n\nHuman: ")
			builder.WriteString(prompt)
			builder.WriteString("\n\nAssistant:")
			body, err = json.Marshal(bedrockAnthropicGenerateRequest{
				Prompt:            builder.String(),
				MaxTokensToSample: *settings.MaxTokenCount(),
				Temperature:       *settings.Temperature(),
				TopK:              *settings.TopK(),
				TopP:              settings.TopP(),
				StopSequences:     settings.StopSequences(),
				AnthropicVersion:  "bedrock-2023-05-31",
			})
		case v.isAI21Model(model):
			body, err = json.Marshal(bedrockAI21GenerateRequest{
				Prompt:        prompt,
				MaxTokens:     *settings.MaxTokenCount(),
				Temperature:   *settings.Temperature(),
				TopP:          settings.TopP(),
				StopSequences: settings.StopSequences(),
			})
		case v.isCohereModel(model):
			body, err = json.Marshal(bedrockCohereRequest{
				Prompt:      prompt,
				Temperature: *settings.Temperature(),
				MaxTokens:   *settings.MaxTokenCount(),
				// ReturnLikeliHood: "GENERATION", // contrary to docs, this is invalid
			})
		}

		headers["x-amzn-bedrock-save"] = "false"
		if err != nil {
			return nil, errors.Wrapf(err, "marshal body")
		}

	case v.isSagemaker(service):
		endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint)
		host = "runtime." + service + "." + region + ".amazonaws.com"
		path = "/endpoints/" + endpoint + "/invocations"

		// The target headers are only sent when configured.
		if targetModel != "" {
			headers["x-amzn-sagemaker-target-model"] = targetModel
		}
		if targetVariant != "" {
			headers["x-amzn-sagemaker-target-variant"] = targetVariant
		}

		body, err = json.Marshal(sagemakerGenerateRequest{
			Prompt: prompt,
		})
		if err != nil {
			return nil, errors.Wrapf(err, "marshal body")
		}

	default:
		// err is still nil here and Wrapf returns nil for a nil error, so
		// wrapping it would hand the caller a nil response together with a
		// nil error. Create a real error instead.
		return nil, errors.Errorf("service error: unsupported service %q", service)
	}

	accessKey, err := v.getAwsAccessKey(ctx)
	if err != nil {
		return nil, errors.Wrapf(err, "AWS Access Key")
	}
	secretKey, err := v.getAwsAccessSecret(ctx)
	if err != nil {
		return nil, errors.Wrapf(err, "AWS Secret Key")
	}

	headers["host"] = host
	amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers)
	headers["Authorization"] = authorizationHeader
	headers["x-amz-date"] = amzDate

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	// Any other service was rejected above, so only these two remain.
	if v.isBedrock(service) {
		return v.parseBedrockResponse(bodyBytes, res)
	}
	return v.parseSagemakerResponse(bodyBytes, res)
}
// parseBedrockResponse decodes a Bedrock response body and extracts the
// generated text.
//
// It returns an error when the body is not valid JSON for
// bedrockGenerateResponse, when the response carries a "message" field or a
// status other than 200, or when both the results and generations lists are
// empty. Otherwise the text comes from the first result (if it has a
// completion reason) or else from the first generation; an empty text yields a
// response whose Result is nil.
//
// The body is decoded once. The earlier extra decode into a generic map was
// never read and only repeated the work.
func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) {
	var resBody bedrockGenerateResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	// A message in the body is treated as a failure regardless of the status.
	if resBody.Message != nil {
		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s",
			res.StatusCode, *resBody.Message)
	}
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode)
	}

	if len(resBody.Results) == 0 && len(resBody.Generations) == 0 {
		return nil, fmt.Errorf("received empty response from AWS Bedrock")
	}

	var content string
	switch {
	case len(resBody.Results) > 0 && len(resBody.Results[0].CompletionReason) > 0:
		content = resBody.Results[0].OutputText
	case len(resBody.Generations) > 0:
		content = resBody.Generations[0].Text
	}

	if content == "" {
		return &generativemodels.GenerateResponse{Result: nil}, nil
	}
	return &generativemodels.GenerateResponse{Result: &content}, nil
}
// parseSagemakerResponse decodes a Sagemaker response body and extracts the
// generated text.
//
// It returns an error when the body cannot be decoded, when the response
// carries a "message" field or a status other than 200, or when the
// generations list is empty. The text of the first generation is returned only
// if that generation has a non-empty id and a non-empty text; otherwise the
// response's Result is nil.
func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) {
	var resBody sagemakerGenerateResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	// A message in the body is treated as a failure regardless of the status.
	if resBody.Message != nil {
		return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %v error: %s",
			res.StatusCode, *resBody.Message)
	}
	if res.StatusCode != 200 {
		return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %d", res.StatusCode)
	}

	if len(resBody.Generations) == 0 {
		return nil, fmt.Errorf("received empty response from AWS Sagemaker")
	}

	first := resBody.Generations[0]
	if first.Id == "" || first.Text == "" {
		return &generativemodels.GenerateResponse{Result: nil}, nil
	}

	content := first.Text
	return &generativemodels.GenerateResponse{Result: &content}, nil
}
// isSagemaker reports whether service is exactly "sagemaker".
func (v *aws) isSagemaker(service string) bool {
	const sagemakerService = "sagemaker"
	return service == sagemakerService
}
// isBedrock reports whether service is exactly "bedrock".
func (v *aws) isBedrock(service string) bool {
	const bedrockService = "bedrock"
	return service == bedrockService
}
// generatePromptForTask builds the prompt for a task over several objects: a
// single quote, the task text, a colon, a newline, and then textProperties
// encoded as JSON. A JSON encoding error is returned unchanged.
func (v *aws) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
	encoded, err := json.Marshal(textProperties)
	if err != nil {
		return "", err
	}

	taskPrompt := fmt.Sprintf(`'%v:
%v`, task, string(encoded))
	return taskPrompt, nil
}
// generateForPrompt replaces every `{property}` placeholder in prompt with the
// matching value from textProperties. The property name is the text between
// the braces with surrounding whitespace trimmed.
//
// The placeholders are collected from the prompt as it was passed in, before
// any replacement is made. An error is returned as soon as a placeholder names
// a property whose value is missing or empty.
func (v *aws) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
	for _, groups := range compile.FindAllStringSubmatch(prompt, -1) {
		placeholder := groups[0]
		name := strings.TrimSpace(groups[1])

		value := textProperties[name]
		if value == "" {
			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", name)
		}

		prompt = strings.ReplaceAll(prompt, placeholder, value)
	}
	return prompt, nil
}
// getAwsAccessKey resolves the AWS access key for a request. A non-empty first
// entry of the []string stored in ctx under "X-Aws-Access-Key" wins; otherwise
// the key the client was created with is used. An error is returned when
// neither is set.
func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) {
	fromContext, ok := ctx.Value("X-Aws-Access-Key").([]string)
	if ok && len(fromContext) > 0 && fromContext[0] != "" {
		return fromContext[0], nil
	}

	if v.awsAccessKey != "" {
		return v.awsAccessKey, nil
	}

	return "", errors.New("no access key found " +
		"neither in request header: X-AWS-Access-Key " +
		"nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY")
}
// getAwsAccessSecret resolves the AWS secret key for a request. A non-empty
// first entry of the []string stored in ctx under "X-Aws-Secret-Key" wins;
// otherwise the secret the client was created with is used. An error is
// returned when neither is set.
func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) {
	fromContext, ok := ctx.Value("X-Aws-Secret-Key").([]string)
	if ok && len(fromContext) > 0 && fromContext[0] != "" {
		return fromContext[0], nil
	}

	if v.awsSecretKey != "" {
		return v.awsSecretKey, nil
	}

	return "", errors.New("no secret found " +
		"neither in request header: X-Aws-Secret-Key " +
		"nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY")
}
// isAmazonModel reports whether the model name contains "amazon".
func (v *aws) isAmazonModel(model string) bool {
	const vendor = "amazon"
	return strings.Contains(model, vendor)
}
// isAI21Model reports whether the model name contains "ai21".
func (v *aws) isAI21Model(model string) bool {
	const vendor = "ai21"
	return strings.Contains(model, vendor)
}
// isAnthropicModel reports whether the model name contains "anthropic".
func (v *aws) isAnthropicModel(model string) bool {
	const vendor = "anthropic"
	return strings.Contains(model, vendor)
}
// isCohereModel reports whether the model name contains "cohere".
func (v *aws) isCohereModel(model string) bool {
	const vendor = "cohere"
	return strings.Contains(model, vendor)
}
// bedrockAmazonGenerateRequest is the JSON request body that Generate marshals
// for Bedrock models matched by isAmazonModel.
type bedrockAmazonGenerateRequest struct {
// InputText carries the prompt.
InputText string `json:"inputText,omitempty"`
// TextGenerationConfig is optional. Generate leaves it nil, so it is omitted
// from the payload.
TextGenerationConfig *textGenerationConfig `json:"textGenerationConfig,omitempty"`
}
// bedrockAnthropicGenerateRequest is the JSON request body that Generate
// marshals for Bedrock models matched by isAnthropicModel. Every field is
// tagged omitempty, so zero values are left out of the payload.
type bedrockAnthropicGenerateRequest struct {
// Prompt is the user prompt wrapped by Generate in "\n\nHuman: " and
// "\n\nAssistant:".
Prompt string `json:"prompt,omitempty"`
// MaxTokensToSample, Temperature, TopK, TopP and StopSequences are filled
// from the class settings.
MaxTokensToSample int `json:"max_tokens_to_sample,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
// AnthropicVersion is set to "bedrock-2023-05-31" by Generate.
AnthropicVersion string `json:"anthropic_version,omitempty"`
}
// bedrockAI21GenerateRequest is the JSON request body that Generate marshals
// for Bedrock models matched by isAI21Model.
//
// All keys are camelCase. TopP and StopSequences were previously tagged
// "top_p" and "stop_sequences", unlike every other key in this request; they
// now use "topP" and "stopSequences" to match the AI21 Jurassic-2 request
// schema.
type bedrockAI21GenerateRequest struct {
	// Prompt carries the prompt unchanged.
	Prompt string `json:"prompt,omitempty"`
	// MaxTokens, Temperature, TopP and StopSequences are filled from the class
	// settings.
	MaxTokens     int      `json:"maxTokens,omitempty"`
	Temperature   float64  `json:"temperature,omitempty"`
	TopP          *float64 `json:"topP,omitempty"`
	StopSequences []string `json:"stopSequences,omitempty"`
	// The penalty fields are not set by Generate. omitempty has no effect on
	// struct values, so each one is still emitted, as an empty object while
	// its Scale is zero.
	CountPenalty     penalty `json:"countPenalty,omitempty"`
	PresencePenalty  penalty `json:"presencePenalty,omitempty"`
	FrequencyPenalty penalty `json:"frequencyPenalty,omitempty"`
}
// bedrockCohereRequest is the JSON request body that Generate marshals for
// Bedrock models matched by isCohereModel.
type bedrockCohereRequest struct {
// Prompt carries the prompt unchanged.
Prompt string `json:"prompt,omitempty"`
// MaxTokens and Temperature are filled from the class settings.
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
// ReturnLikeliHood is not set by Generate, so it is omitted from the payload.
ReturnLikeliHood string `json:"return_likelihood,omitempty"`
}
// penalty is the nested object used for the count, presence and frequency
// penalty fields of bedrockAI21GenerateRequest.
type penalty struct {
// Scale is left out of the JSON when it is zero.
Scale int `json:"scale,omitempty"`
}
// sagemakerGenerateRequest is the JSON request body that Generate marshals for
// the Sagemaker service.
type sagemakerGenerateRequest struct {
// Prompt carries the prompt unchanged.
Prompt string `json:"prompt,omitempty"`
}
// textGenerationConfig is the optional nested configuration object of
// bedrockAmazonGenerateRequest. None of its fields is tagged omitempty, so all
// of them are emitted whenever the object itself is present.
type textGenerationConfig struct {
MaxTokenCount int `json:"maxTokenCount"`
StopSequences []string `json:"stopSequences"`
Temperature float64 `json:"temperature"`
// NOTE(review): TopP is an int here but a *float64 in the other request
// types in this file — confirm this is intended.
TopP int `json:"topP"`
}
// bedrockGenerateResponse is the shape parseBedrockResponse decodes a Bedrock
// response body into.
type bedrockGenerateResponse struct {
InputTextTokenCount int `json:"InputTextTokenCount,omitempty"`
// Results and Generations are the two places parseBedrockResponse looks for
// generated text; it tries the first result before the first generation.
Results []Result `json:"results,omitempty"`
Generations []BedrockGeneration `json:"generations,omitempty"`
// Message, when present, makes parseBedrockResponse return an error that
// includes its text.
Message *string `json:"message,omitempty"`
}
// sagemakerGenerateResponse is the shape parseSagemakerResponse decodes a
// Sagemaker response body into.
type sagemakerGenerateResponse struct {
// Generations holds the generated outputs; only the first one is read.
Generations []Generation `json:"generations,omitempty"`
// Message, when present, makes parseSagemakerResponse return an error that
// includes its text.
Message *string `json:"message,omitempty"`
}
// Generation is one entry of sagemakerGenerateResponse.Generations.
type Generation struct {
// Id must be non-empty for parseSagemakerResponse to use Text.
Id string `json:"id,omitempty"`
// Text is the generated text.
Text string `json:"text,omitempty"`
}
// BedrockGeneration is one entry of bedrockGenerateResponse.Generations.
type BedrockGeneration struct {
Id string `json:"id,omitempty"`
// Text is the generated text; parseBedrockResponse reads it from the first
// generation.
Text string `json:"text,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
// Result is one entry of bedrockGenerateResponse.Results.
type Result struct {
TokenCount int `json:"tokenCount,omitempty"`
// OutputText is the generated text.
OutputText string `json:"outputText,omitempty"`
// CompletionReason must be non-empty for parseBedrockResponse to use
// OutputText.
CompletionReason string `json:"completionReason,omitempty"`
}