// NOTE(review): the following lines are web-viewer residue (author, commit
// message, raw/history/blame links, file size) accidentally pasted above the
// license header; wrapped in a comment so the file compiles — consider
// removing them entirely:
//   KevinStephenson — "Adding in weaviate code" (commit b110593, 11.2 kB)
//   raw | history | blame
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/modules/generative-openai/config"
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
)
// compile matches `{property}` placeholders (word characters and spaces
// between braces); group 1 captures the raw property name.
// MustCompile replaces Compile with a discarded error: the pattern is a
// package-level constant, so a typo should fail loudly at startup instead of
// leaving a nil regexp that panics on first use.
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
// buildUrlFn is the default endpoint builder. When both resourceName and
// deploymentID are set it targets Azure OpenAI; otherwise it joins baseURL
// with the OpenAI completions path (legacy or chat).
func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
	if resourceName != "" && deploymentID != "" {
		// Azure endpoint is fully determined by resource + deployment.
		return fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?%s",
			resourceName, deploymentID, "api-version=2023-03-15-preview"), nil
	}
	endpointPath := "/v1/chat/completions"
	if isLegacy {
		endpointPath = "/v1/completions"
	}
	return url.JoinPath(baseURL, endpointPath)
}
// openai is the client used to call the OpenAI / Azure OpenAI completion APIs
// for generative search. Construct it with New.
type openai struct {
	openAIApiKey       string // static OpenAI key; when empty, the key is read from the request context
	openAIOrganization string // optional value for the OpenAI-Organization header
	azureApiKey        string // static Azure key; when empty, the key is read from the request context
	// buildUrl builds the endpoint URL; injected as a field so tests can stub
	// it (New wires it to buildUrlFn).
	buildUrl   func(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error)
	httpClient *http.Client
	logger     logrus.FieldLogger
}
// New builds an openai client. Empty API keys are allowed: keys may instead be
// supplied per request via headers in the context. timeout bounds every HTTP
// call made by the client.
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *openai {
	client := &openai{
		openAIApiKey:       openAIApiKey,
		openAIOrganization: openAIOrganization,
		azureApiKey:        azureApiKey,
		logger:             logger,
	}
	client.httpClient = &http.Client{Timeout: timeout}
	client.buildUrl = buildUrlFn
	return client
}
// GenerateSingleResult resolves `{property}` placeholders in prompt from
// textProperties, then generates a completion for the resulting prompt.
func (v *openai) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	resolvedPrompt, err := v.generateForPrompt(textProperties, prompt)
	if err != nil {
		return nil, err
	}
	return v.Generate(ctx, cfg, resolvedPrompt)
}
// GenerateAllResults builds a single prompt combining the task description
// with all objects' text properties, then generates one completion for it.
func (v *openai) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
	taskPrompt, err := v.generatePromptForTask(textProperties, task)
	if err != nil {
		return nil, err
	}
	return v.Generate(ctx, cfg, taskPrompt)
}
// Generate sends prompt to the configured (Azure) OpenAI endpoint and returns
// the model's answer. Endpoint, model and sampling parameters come from the
// class config; API keys come from the client or the request context.
// Returns a nil Result when the API reports no usable choice.
func (v *openai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
	settings := config.NewClassSettings(cfg)

	oaiUrl, err := v.buildOpenAIUrl(ctx, settings)
	if err != nil {
		return nil, errors.Wrap(err, "url join path")
	}
	input, err := v.generateInput(prompt, settings)
	if err != nil {
		return nil, errors.Wrap(err, "generate input")
	}
	body, err := json.Marshal(input)
	if err != nil {
		return nil, errors.Wrap(err, "marshal body")
	}

	req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl, bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}
	apiKey, err := v.getApiKey(ctx, settings.IsAzure())
	if err != nil {
		return nil, errors.Wrap(err, "OpenAI API Key")
	}
	// Azure and OpenAI use different auth headers; the helper picks the right one.
	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure()))
	if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
		req.Header.Add("OpenAI-Organization", openAIOrganization)
	}
	req.Header.Add("Content-Type", "application/json")

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}
	var resBody generateResponse
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}
	if res.StatusCode != http.StatusOK || resBody.Error != nil {
		return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure())
	}

	// BUG FIX: the previous code read resBody.Choices[0] before checking that
	// any choices were returned, panicking with an index-out-of-range on an
	// empty "choices" array. Guard first, then inspect the first choice.
	if len(resBody.Choices) == 0 {
		return &generativemodels.GenerateResponse{Result: nil}, nil
	}
	first := resBody.Choices[0]
	if first.Text != "" {
		// legacy completions endpoint answers in Text
		trimmed := strings.Trim(first.Text, "\n")
		return &generativemodels.GenerateResponse{Result: &trimmed}, nil
	}
	if first.Message != nil {
		// chat completions endpoint answers in Message.Content
		trimmed := strings.Trim(first.Message.Content, "\n")
		return &generativemodels.GenerateResponse{Result: &trimmed}, nil
	}
	return &generativemodels.GenerateResponse{Result: nil}, nil
}
// buildOpenAIUrl computes the request URL from the class settings, letting a
// per-request X-Openai-Baseurl header override the configured base URL.
func (v *openai) buildOpenAIUrl(ctx context.Context, settings config.ClassSettings) (string, error) {
	baseURL := settings.BaseURL()
	headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl")
	if headerBaseURL != "" {
		baseURL = headerBaseURL
	}
	return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL)
}
// generateInput builds the request body for the configured endpoint: a legacy
// completions payload (Prompt) when IsLegacy, otherwise a chat-completions
// payload (Messages). The legacy branch returns early so the chat path stays
// unindented (removes the needless else-after-return).
func (v *openai) generateInput(prompt string, settings config.ClassSettings) (generateInput, error) {
	if settings.IsLegacy() {
		return generateInput{
			Prompt:           prompt,
			Model:            settings.Model(),
			MaxTokens:        settings.MaxTokens(),
			Temperature:      settings.Temperature(),
			FrequencyPenalty: settings.FrequencyPenalty(),
			PresencePenalty:  settings.PresencePenalty(),
			TopP:             settings.TopP(),
		}, nil
	}

	messages := []message{{
		Role:    "user",
		Content: prompt,
	}}
	// Cap max_tokens so that prompt plus answer fit the model's token limit.
	tokens, err := v.determineTokens(settings.GetMaxTokensForModel(settings.Model()), settings.MaxTokens(), settings.Model(), messages)
	if err != nil {
		return generateInput{}, errors.Wrap(err, "determine tokens count")
	}
	input := generateInput{
		Messages:         messages,
		MaxTokens:        tokens,
		Temperature:      settings.Temperature(),
		FrequencyPenalty: settings.FrequencyPenalty(),
		PresencePenalty:  settings.PresencePenalty(),
		TopP:             settings.TopP(),
	}
	if !settings.IsAzure() {
		// model is mandatory for OpenAI calls, but obsolete for Azure calls
		input.Model = settings.Model()
	}
	return input, nil
}
// getError formats an HTTP-level or API-level failure into a single error,
// naming the endpoint (OpenAI or Azure OpenAI) that was called.
func (v *openai) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
	endpoint := "OpenAI API"
	if isAzure {
		endpoint = "Azure OpenAI API"
	}
	if resBodyError == nil {
		return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
	}
	return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
}
// determineTokens decides the max_tokens value to send with a chat request.
// maxTokensSetting is the model's overall token limit, classSetting the
// user-configured answer length, and messages the prompt about to be sent.
func (v *openai) determineTokens(maxTokensSetting float64, classSetting float64, model string, messages []message) (float64, error) {
	// getTokensCount is defined elsewhere in this package; it counts the
	// tokens the messages occupy for the given model.
	tokenMessagesCount, err := getTokensCount(model, messages)
	if err != nil {
		return 0, err
	}
	messageTokens := float64(tokenMessagesCount)
	if messageTokens+classSetting >= maxTokensSetting {
		// max token limit must be in range: [1, maxTokensSetting) that's why -1 is added
		// NOTE(review): if messageTokens >= maxTokensSetting this yields a
		// value <= 0, which the API will reject — consider clamping; confirm.
		return maxTokensSetting - messageTokens - 1, nil
	}
	// NOTE(review): this returns the prompt's token count rather than
	// classSetting; returning classSetting (the configured answer budget)
	// looks like the intended behavior — verify with callers before changing.
	return messageTokens, nil
}
// getApiKeyHeaderAndValue returns the auth header name and value: Azure uses
// a bare "api-key" header, OpenAI a standard Bearer Authorization header.
func (v *openai) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
	if !isAzure {
		return "Authorization", "Bearer " + apiKey
	}
	return "api-key", apiKey
}
// generatePromptForTask concatenates the task description with the JSON
// serialization of all objects' text properties into one prompt.
func (v *openai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
	serialized, err := json.Marshal(textProperties)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf(`'%v:
%v`, task, string(serialized)), nil
}
// generateForPrompt substitutes every `{property}` placeholder in prompt with
// the matching value from textProperties. It errors when a referenced
// property is missing or empty.
func (v *openai) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
	// One pass collects each full placeholder (match[0]) together with its
	// captured property name (match[1]).
	for _, match := range compile.FindAllStringSubmatch(prompt, -1) {
		placeholder := match[0]
		propertyName := strings.TrimSpace(match[1])
		value := textProperties[propertyName]
		if value == "" {
			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", propertyName)
		}
		prompt = strings.ReplaceAll(prompt, placeholder, value)
	}
	return prompt, nil
}
// getApiKey returns the API key for the target service: a statically
// configured key wins; otherwise the key is looked up in the request context
// under the service-specific header name.
func (v *openai) getApiKey(ctx context.Context, isAzure bool) (string, error) {
	if isAzure {
		if v.azureApiKey != "" {
			return v.azureApiKey, nil
		}
		return v.getApiKeyFromContext(ctx, "X-Azure-Api-Key", "AZURE_APIKEY")
	}
	if v.openAIApiKey != "" {
		return v.openAIApiKey, nil
	}
	return v.getApiKeyFromContext(ctx, "X-Openai-Api-Key", "OPENAI_APIKEY")
}
// getApiKeyFromContext reads an API key from the request context under the
// given header name; envVar only names the alternative environment variable
// in the error message.
func (v *openai) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
	apiKeyValue := v.getValueFromContext(ctx, apiKey)
	if apiKeyValue == "" {
		return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
	}
	return apiKeyValue, nil
}
// getValueFromContext looks up a header value stored in the context under
// key, first as an HTTP-style []string context value, then via GRPC metadata.
// Returns "" when neither carries a non-empty value.
func (v *openai) getValueFromContext(ctx context.Context, key string) string {
	raw := ctx.Value(key)
	if headerValues, ok := raw.([]string); ok && len(headerValues) > 0 && headerValues[0] != "" {
		return headerValues[0]
	}
	// try getting header from GRPC if not successful
	grpcValues := modulecomponents.GetValueFromGRPC(ctx, key)
	if len(grpcValues) > 0 && grpcValues[0] != "" {
		return grpcValues[0]
	}
	return ""
}
// getOpenAIOrganization returns the per-request organization header value if
// present in the context, falling back to the statically configured one.
func (v *openai) getOpenAIOrganization(ctx context.Context) string {
	org := v.getValueFromContext(ctx, "X-Openai-Organization")
	if org == "" {
		org = v.openAIOrganization
	}
	return org
}
// generateInput is the JSON request body for both the legacy completions
// endpoint (Prompt set) and the chat completions endpoint (Messages set).
type generateInput struct {
	Prompt           string    `json:"prompt,omitempty"`   // legacy /v1/completions only
	Messages         []message `json:"messages,omitempty"` // chat completions only
	Model            string    `json:"model,omitempty"`    // omitted for Azure, where the deployment implies the model
	MaxTokens        float64   `json:"max_tokens"`
	Temperature      float64   `json:"temperature"`
	Stop             []string  `json:"stop"`
	FrequencyPenalty float64   `json:"frequency_penalty"`
	PresencePenalty  float64   `json:"presence_penalty"`
	TopP             float64   `json:"top_p"`
}
// message is a single chat message in OpenAI's chat-completions format.
type message struct {
	Role    string `json:"role"` // e.g. "user" (the only role this client sends)
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}
// generateResponse is the JSON response body of a completion call.
// Choices carries no tag; encoding/json matches the lowercase "choices"
// field case-insensitively on unmarshal.
type generateResponse struct {
	Choices []choice
	Error   *openAIApiError `json:"error,omitempty"` // non-nil when the API reports an error
}
// choice is one completion alternative: the legacy endpoint fills Text, the
// chat endpoint fills Message. Untagged fields rely on case-insensitive
// JSON matching.
type choice struct {
	FinishReason string
	Index        float32
	Logprobs     string
	Text         string   `json:"text,omitempty"`
	Message      *message `json:"message,omitempty"`
}
// openAIApiError is the error object embedded in OpenAI API error responses.
type openAIApiError struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   string      `json:"param"`
	Code    json.Number `json:"code"` // json.Number: the API returns both string and numeric codes
}