Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package clients | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"net/url" | |
"regexp" | |
"strings" | |
"time" | |
"github.com/weaviate/weaviate/usecases/modulecomponents" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/modules/generative-openai/config" | |
generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" | |
) | |
// compile matches single-brace template placeholders like "{title}" in a
// prompt; the capture group holds the property name (word characters and
// whitespace). The pattern is a compile-time constant, so use MustCompile
// and fail loudly at init instead of silently discarding the error as the
// previous `var compile, _ = regexp.Compile(...)` did.
var compile = regexp.MustCompile(`{([\w\s]*?)}`)
// buildUrlFn resolves the completions endpoint URL. When both an Azure
// resource name and deployment ID are configured, the Azure OpenAI service
// endpoint is targeted; otherwise the (possibly header-overridden) base URL
// is joined with the OpenAI path, selecting the legacy completions path
// when isLegacy is set.
func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
	useAzure := resourceName != "" && deploymentID != ""
	if useAzure {
		return fmt.Sprintf(
			"https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview",
			resourceName, deploymentID), nil
	}
	endpoint := "/v1/chat/completions"
	if isLegacy {
		endpoint = "/v1/completions"
	}
	return url.JoinPath(baseURL, endpoint)
}
// openai is the generative-module client for the OpenAI and Azure OpenAI
// completions APIs.
type openai struct {
	// Statically configured OpenAI API key; may be empty, in which case the
	// key is resolved per request from headers/gRPC metadata (see getApiKey).
	openAIApiKey string
	// Default OpenAI organization; overridable per request via the
	// X-Openai-Organization header (see getOpenAIOrganization).
	openAIOrganization string
	// Statically configured Azure API key; same fallback behavior as above.
	azureApiKey string
	// buildUrl is injected (set to buildUrlFn in New) so URL construction
	// can be swapped out in tests.
	buildUrl   func(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error)
	httpClient *http.Client
	logger     logrus.FieldLogger
}
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *openai { | |
return &openai{ | |
openAIApiKey: openAIApiKey, | |
openAIOrganization: openAIOrganization, | |
azureApiKey: azureApiKey, | |
httpClient: &http.Client{ | |
Timeout: timeout, | |
}, | |
buildUrl: buildUrlFn, | |
logger: logger, | |
} | |
} | |
func (v *openai) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
forPrompt, err := v.generateForPrompt(textProperties, prompt) | |
if err != nil { | |
return nil, err | |
} | |
return v.Generate(ctx, cfg, forPrompt) | |
} | |
func (v *openai) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { | |
forTask, err := v.generatePromptForTask(textProperties, task) | |
if err != nil { | |
return nil, err | |
} | |
return v.Generate(ctx, cfg, forTask) | |
} | |
func (v *openai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { | |
settings := config.NewClassSettings(cfg) | |
oaiUrl, err := v.buildOpenAIUrl(ctx, settings) | |
if err != nil { | |
return nil, errors.Wrap(err, "url join path") | |
} | |
input, err := v.generateInput(prompt, settings) | |
if err != nil { | |
return nil, errors.Wrap(err, "generate input") | |
} | |
body, err := json.Marshal(input) | |
if err != nil { | |
return nil, errors.Wrap(err, "marshal body") | |
} | |
req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl, | |
bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
apiKey, err := v.getApiKey(ctx, settings.IsAzure()) | |
if err != nil { | |
return nil, errors.Wrapf(err, "OpenAI API Key") | |
} | |
req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure())) | |
if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" { | |
req.Header.Add("OpenAI-Organization", openAIOrganization) | |
} | |
req.Header.Add("Content-Type", "application/json") | |
res, err := v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
var resBody generateResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if res.StatusCode != 200 || resBody.Error != nil { | |
return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure()) | |
} | |
textResponse := resBody.Choices[0].Text | |
if len(resBody.Choices) > 0 && textResponse != "" { | |
trimmedResponse := strings.Trim(textResponse, "\n") | |
return &generativemodels.GenerateResponse{ | |
Result: &trimmedResponse, | |
}, nil | |
} | |
message := resBody.Choices[0].Message | |
if message != nil { | |
textResponse = message.Content | |
trimmedResponse := strings.Trim(textResponse, "\n") | |
return &generativemodels.GenerateResponse{ | |
Result: &trimmedResponse, | |
}, nil | |
} | |
return &generativemodels.GenerateResponse{ | |
Result: nil, | |
}, nil | |
} | |
func (v *openai) buildOpenAIUrl(ctx context.Context, settings config.ClassSettings) (string, error) { | |
baseURL := settings.BaseURL() | |
if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" { | |
baseURL = headerBaseURL | |
} | |
return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL) | |
} | |
func (v *openai) generateInput(prompt string, settings config.ClassSettings) (generateInput, error) { | |
if settings.IsLegacy() { | |
return generateInput{ | |
Prompt: prompt, | |
Model: settings.Model(), | |
MaxTokens: settings.MaxTokens(), | |
Temperature: settings.Temperature(), | |
FrequencyPenalty: settings.FrequencyPenalty(), | |
PresencePenalty: settings.PresencePenalty(), | |
TopP: settings.TopP(), | |
}, nil | |
} else { | |
var input generateInput | |
messages := []message{{ | |
Role: "user", | |
Content: prompt, | |
}} | |
tokens, err := v.determineTokens(settings.GetMaxTokensForModel(settings.Model()), settings.MaxTokens(), settings.Model(), messages) | |
if err != nil { | |
return input, errors.Wrap(err, "determine tokens count") | |
} | |
input = generateInput{ | |
Messages: messages, | |
MaxTokens: tokens, | |
Temperature: settings.Temperature(), | |
FrequencyPenalty: settings.FrequencyPenalty(), | |
PresencePenalty: settings.PresencePenalty(), | |
TopP: settings.TopP(), | |
} | |
if !settings.IsAzure() { | |
// model is mandatory for OpenAI calls, but obsolete for Azure calls | |
input.Model = settings.Model() | |
} | |
return input, nil | |
} | |
} | |
func (v *openai) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error { | |
endpoint := "OpenAI API" | |
if isAzure { | |
endpoint = "Azure OpenAI API" | |
} | |
if resBodyError != nil { | |
return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) | |
} | |
return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) | |
} | |
func (v *openai) determineTokens(maxTokensSetting float64, classSetting float64, model string, messages []message) (float64, error) { | |
tokenMessagesCount, err := getTokensCount(model, messages) | |
if err != nil { | |
return 0, err | |
} | |
messageTokens := float64(tokenMessagesCount) | |
if messageTokens+classSetting >= maxTokensSetting { | |
// max token limit must be in range: [1, maxTokensSetting) that's why -1 is added | |
return maxTokensSetting - messageTokens - 1, nil | |
} | |
return messageTokens, nil | |
} | |
func (v *openai) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) { | |
if isAzure { | |
return "api-key", apiKey | |
} | |
return "Authorization", fmt.Sprintf("Bearer %s", apiKey) | |
} | |
func (v *openai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { | |
marshal, err := json.Marshal(textProperties) | |
if err != nil { | |
return "", err | |
} | |
return fmt.Sprintf(`'%v: | |
%v`, task, string(marshal)), nil | |
} | |
func (v *openai) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { | |
all := compile.FindAll([]byte(prompt), -1) | |
for _, match := range all { | |
originalProperty := string(match) | |
replacedProperty := compile.FindStringSubmatch(originalProperty)[1] | |
replacedProperty = strings.TrimSpace(replacedProperty) | |
value := textProperties[replacedProperty] | |
if value == "" { | |
return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) | |
} | |
prompt = strings.ReplaceAll(prompt, originalProperty, value) | |
} | |
return prompt, nil | |
} | |
func (v *openai) getApiKey(ctx context.Context, isAzure bool) (string, error) { | |
var apiKey, envVar string | |
if isAzure { | |
apiKey = "X-Azure-Api-Key" | |
envVar = "AZURE_APIKEY" | |
if len(v.azureApiKey) > 0 { | |
return v.azureApiKey, nil | |
} | |
} else { | |
apiKey = "X-Openai-Api-Key" | |
envVar = "OPENAI_APIKEY" | |
if len(v.openAIApiKey) > 0 { | |
return v.openAIApiKey, nil | |
} | |
} | |
return v.getApiKeyFromContext(ctx, apiKey, envVar) | |
} | |
func (v *openai) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { | |
if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { | |
return apiKeyValue, nil | |
} | |
return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) | |
} | |
func (v *openai) getValueFromContext(ctx context.Context, key string) string { | |
if value := ctx.Value(key); value != nil { | |
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { | |
return keyHeader[0] | |
} | |
} | |
// try getting header from GRPC if not successful | |
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { | |
return apiKey[0] | |
} | |
return "" | |
} | |
func (v *openai) getOpenAIOrganization(ctx context.Context) string { | |
if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" { | |
return value | |
} | |
return v.openAIOrganization | |
} | |
// generateInput is the request payload for both OpenAI endpoints: Prompt is
// used by the legacy completions API, Messages by the chat completions API.
type generateInput struct {
	Prompt   string    `json:"prompt,omitempty"`   // legacy completions only
	Messages []message `json:"messages,omitempty"` // chat completions only
	// Model is left empty for Azure calls, where the deployment determines
	// the model (see generateInput method).
	Model            string   `json:"model,omitempty"`
	MaxTokens        float64  `json:"max_tokens"`
	Temperature      float64  `json:"temperature"`
	Stop             []string `json:"stop"`
	FrequencyPenalty float64  `json:"frequency_penalty"`
	PresencePenalty  float64  `json:"presence_penalty"`
	TopP             float64  `json:"top_p"`
}
// message is a single chat turn in the OpenAI chat completions format.
type message struct {
	Role    string `json:"role"` // e.g. "user"; only "user" is sent by this client
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}
// generateResponse mirrors the parts of the OpenAI completions response that
// this client reads; Error is populated on API-level failures.
type generateResponse struct {
	Choices []choice
	Error   *openAIApiError `json:"error,omitempty"`
}
type choice struct { | |
FinishReason string | |
Index float32 | |
Logprobs string | |
Text string `json:"text,omitempty"` | |
Message *message `json:"message,omitempty"` | |
} | |
// openAIApiError is the error object embedded in OpenAI API error responses.
type openAIApiError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	// json.Number keeps the numeric error code as its literal text instead
	// of converting it through float64.
	Code json.Number `json:"code"`
}