KevinStephenson
Adding in weaviate code
b110593
raw
history blame
6.89 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/modules/generative-anyscale/config"
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
)
// compile matches prompt placeholders of the form "{propertyName}" and
// captures the text between the braces (letters, digits, underscore and
// whitespace only). MustCompile makes an invalid pattern fail loudly at
// package initialisation instead of silently leaving compile nil.
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
// anyscale is the HTTP client used to call Anyscale's chat-completions
// endpoint for the generative-anyscale module.
type anyscale struct {
	// apiKey is the fallback key used when a request carries no
	// X-Anyscale-Api-Key value (see getApiKey).
	apiKey string
	// httpClient sends the completion requests; New sets its Timeout.
	httpClient *http.Client
	// logger is stored by New; no code shown in this file reads it.
	logger logrus.FieldLogger
}
// New constructs an Anyscale client. apiKey is used whenever a request does
// not supply its own X-Anyscale-Api-Key value, timeout bounds every HTTP
// request the client makes, and logger is kept on the client.
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *anyscale {
	client := &http.Client{Timeout: timeout}
	return &anyscale{
		logger:     logger,
		httpClient: client,
		apiKey:     apiKey,
	}
}
// GenerateSingleResult fills the "{property}" placeholders in prompt with the
// values from textProperties and sends the resulting text to Generate.
// Placeholder errors are returned unchanged and no request is made.
func (v *anyscale) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	filledPrompt, fillErr := v.generateForPrompt(textProperties, prompt)
	if fillErr != nil {
		return nil, fillErr
	}
	return v.Generate(ctx, cfg, filledPrompt)
}
// GenerateAllResults builds one grouped prompt from task and the text
// properties of every object, then sends it to Generate. A failure to build
// the prompt is returned unchanged and no request is made.
func (v *anyscale) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	groupedPrompt, buildErr := v.generatePromptForTask(textProperties, task)
	if buildErr != nil {
		return nil, buildErr
	}
	return v.Generate(ctx, cfg, groupedPrompt)
}
// Generate sends prompt to the Anyscale chat-completions endpoint and returns
// the content of the first returned choice.
//
// The model and temperature come from the class settings. The base URL also
// comes from the class settings, unless the X-Anyscale-Baseurl context value
// overrides it. The API key is resolved by getApiKey. An error is returned
// when:
//   - the request cannot be built or sent,
//   - the response is not HTTP 200 or carries an API error object,
//   - the body cannot be decoded, or
//   - the response contains no choices.
func (v *anyscale) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
	settings := config.NewClassSettings(cfg)
	anyscaleUrl := v.getAnyscaleUrl(ctx, settings.BaseURL())

	input := generateInput{
		Messages: []map[string]string{
			{"role": "system", "content": "You are a helpful assistant."},
			{"role": "user", "content": prompt},
		},
		Model:       settings.Model(),
		Temperature: settings.Temperature(),
	}

	body, err := json.Marshal(input)
	if err != nil {
		return nil, errors.Wrap(err, "marshal body")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, anyscaleUrl, bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}

	apiKey, err := v.getApiKey(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "Anyscale (OpenAI) API Key")
	}
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
	req.Header.Add("Content-Type", "application/json")

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	var resBody generateResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		// Failed requests may come back with a non-JSON body (e.g. an HTML
		// error page from a proxy); report the HTTP status rather than only
		// a decode error.
		if res.StatusCode != http.StatusOK {
			return nil, errors.Wrapf(err, "connection to Anyscale API failed with status: %d", res.StatusCode)
		}
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	if res.StatusCode != http.StatusOK || resBody.Error != nil {
		if resBody.Error != nil {
			return nil, errors.Errorf("connection to Anyscale API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message)
		}
		return nil, errors.Errorf("connection to Anyscale API failed with status: %d", res.StatusCode)
	}

	// Guard the index below: a 200 response without choices would otherwise
	// panic.
	if len(resBody.Choices) == 0 {
		return nil, errors.New("Anyscale API response contained no choices")
	}

	textResponse := resBody.Choices[0].Message.Content
	return &generativemodels.GenerateResponse{
		Result: &textResponse,
	}, nil
}
// getAnyscaleUrl returns the chat-completions endpoint URL.
//
// baseURL (from the class settings) is used unless the request context
// carries a non-empty X-Anyscale-Baseurl value, which takes precedence.
// Trailing slashes are stripped from the chosen base so that
// "https://host/" yields "https://host/v1/chat/completions" rather than a
// double slash.
func (v *anyscale) getAnyscaleUrl(ctx context.Context, baseURL string) string {
	passedBaseURL := baseURL
	if headerBaseURL := v.getValueFromContext(ctx, "X-Anyscale-Baseurl"); headerBaseURL != "" {
		passedBaseURL = headerBaseURL
	}
	return fmt.Sprintf("%s/v1/chat/completions", strings.TrimRight(passedBaseURL, "/"))
}
// generatePromptForTask builds the grouped-task prompt. The result is a single
// quote, task, a colon and a newline, followed by the JSON encoding of every
// object's text properties. A JSON encoding error is returned unchanged.
//
// NOTE(review): the leading "'" is never closed; it is kept as-is to preserve
// the prompt format, but it looks like a typo worth confirming.
func (v *anyscale) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
	encodedProperties, encodeErr := json.Marshal(textProperties)
	if encodeErr != nil {
		return "", encodeErr
	}
	return "'" + task + ":\n" + string(encodedProperties), nil
}
// generateForPrompt replaces every "{property}" placeholder in prompt with the
// matching value from textProperties. Whitespace inside the braces is ignored,
// so "{ title }" and "{title}" name the same property.
//
// Substitution happens in a single left-to-right pass over the original
// prompt, so text inserted from a property value is never itself scanned for
// placeholders. An error is returned for the first placeholder, in prompt
// order, whose property is missing or empty.
func (v *anyscale) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
	var firstErr error
	filled := compile.ReplaceAllStringFunc(prompt, func(placeholder string) string {
		if firstErr != nil {
			// An earlier placeholder already failed; the result is discarded.
			return placeholder
		}
		// The match is always "{" + name + "}", so the braces can be sliced off
		// instead of running the regex a second time.
		name := strings.TrimSpace(placeholder[1 : len(placeholder)-1])
		value := textProperties[name]
		if value == "" {
			firstErr = errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", name)
			return placeholder
		}
		return value
	})
	if firstErr != nil {
		return "", firstErr
	}
	return filled, nil
}
// getValueFromContext returns the first element of the []string stored in ctx
// under key, provided that element is non-empty. Otherwise it falls back to
// the gRPC lookup via modulecomponents.GetValueFromGRPC, again taking a
// non-empty first element. It returns "" when neither source has a value.
func (v *anyscale) getValueFromContext(ctx context.Context, key string) string {
	// presumably header values placed in the context by the caller — the code
	// only requires a []string stored under the plain string key.
	if raw := ctx.Value(key); raw != nil {
		if values, ok := raw.([]string); ok && len(values) > 0 && values[0] != "" {
			return values[0]
		}
	}

	// Fall back to the gRPC-provided values.
	grpcValues := modulecomponents.GetValueFromGRPC(ctx, key)
	if len(grpcValues) == 0 || grpcValues[0] == "" {
		return ""
	}
	return grpcValues[0]
}
// getApiKey resolves the bearer key for a request. A non-empty
// X-Anyscale-Api-Key value from the request context wins over the key stored
// on the client by New. When neither is set, an error is returned.
func (v *anyscale) getApiKey(ctx context.Context) (string, error) {
	// Note: Anyscale uses the OpenAI API key in its requests.
	key := v.getValueFromContext(ctx, "X-Anyscale-Api-Key")
	if key == "" {
		key = v.apiKey
	}
	if key == "" {
		return "", errors.New("no api key found " +
			"neither in request header: X-Anyscale-Api-Key " +
			"nor in environment variable under ANYSCALE_APIKEY")
	}
	return key, nil
}
// generateInput is the JSON request body that Generate POSTs to the
// chat-completions endpoint.
type generateInput struct {
	// Model is the model name taken from the class settings.
	Model string `json:"model"`
	// Messages is the chat history; Generate sends a system message followed
	// by the user prompt, each as a map with "role" and "content" keys.
	Messages []map[string]string `json:"messages"`
	// NOTE(review): declared as int, so fractional temperatures (e.g. 0.7)
	// cannot be expressed here. Confirm against the type returned by
	// settings.Temperature() and what the Anyscale API accepts.
	Temperature int `json:"temperature"`
}
// Message is a single chat message inside a completion Choice.
type Message struct {
	// Role is the message author's role as reported by the API.
	Role string `json:"role"`
	// Content is the message text; Generate returns the first choice's Content.
	Content string `json:"content"`
}
// Choice is one completion alternative in a generateResponse.
type Choice struct {
	// Message holds the generated chat message.
	Message Message `json:"message"`
	// Index is the position of this choice as reported by the API.
	Index int `json:"index"`
	// FinishReason is decoded from "finish_reason" but not inspected in this file.
	FinishReason string `json:"finish_reason"`
}
// generateResponse is the decoded body of a chat-completions response.
//
// Error responses have a different overall shape. For those, the "error"
// object is decoded into Error, and any fields absent from the body keep
// their zero values. Adding omitempty to these tags would not change that:
// omitempty only affects encoding, not decoding.
type generateResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	// Choices holds the completions; Generate reads only the first one.
	Choices []Choice `json:"choices"`
	// Usage is decoded as a map of integer counters; it is not read in this file.
	Usage map[string]int `json:"usage"`
	// Error is non-nil when the body contained an "error" object.
	Error *anyscaleApiError `json:"error,omitempty"`
}
// anyscaleApiError is the "error" object of an API error response; only its
// message is decoded, and Generate includes it in the returned error.
type anyscaleApiError struct {
	Message string `json:"message"`
}