Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package clients | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"net/url" | |
"strings" | |
"time" | |
"github.com/weaviate/weaviate/usecases/modulecomponents" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/modules/qna-openai/config" | |
"github.com/weaviate/weaviate/modules/qna-openai/ent" | |
) | |
// buildUrl returns the completions endpoint URL for a request.
//
// When both resourceName and deploymentID are non-empty, the Azure OpenAI
// deployment endpoint (api-version 2022-12-01) is returned and baseURL is
// ignored. Otherwise "/v1/completions" is joined onto baseURL with
// url.JoinPath, whose error (e.g. for an unparsable baseURL) is returned.
func buildUrl(baseURL, resourceName, deploymentID string) (string, error) {
	useAzure := resourceName != "" && deploymentID != ""
	if !useAzure {
		return url.JoinPath(baseURL, "/v1/completions")
	}
	azureURL := "https://" + resourceName + ".openai.azure.com" +
		"/openai/deployments/" + deploymentID + "/completions" +
		"?api-version=2022-12-01"
	return azureURL, nil
}
type qna struct { | |
openAIApiKey string | |
openAIOrganization string | |
azureApiKey string | |
buildUrlFn func(baseURL, resourceName, deploymentID string) (string, error) | |
httpClient *http.Client | |
logger logrus.FieldLogger | |
} | |
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *qna { | |
return &qna{ | |
openAIApiKey: openAIApiKey, | |
openAIOrganization: openAIOrganization, | |
azureApiKey: azureApiKey, | |
httpClient: &http.Client{Timeout: timeout}, | |
buildUrlFn: buildUrl, | |
logger: logger, | |
} | |
} | |
func (v *qna) Answer(ctx context.Context, text, question string, cfg moduletools.ClassConfig) (*ent.AnswerResult, error) { | |
prompt := v.generatePrompt(text, question) | |
settings := config.NewClassSettings(cfg) | |
body, err := json.Marshal(answersInput{ | |
Prompt: prompt, | |
Model: settings.Model(), | |
MaxTokens: settings.MaxTokens(), | |
Temperature: settings.Temperature(), | |
Stop: []string{"\n"}, | |
FrequencyPenalty: settings.FrequencyPenalty(), | |
PresencePenalty: settings.PresencePenalty(), | |
TopP: settings.TopP(), | |
}) | |
if err != nil { | |
return nil, errors.Wrapf(err, "marshal body") | |
} | |
oaiUrl, err := v.buildOpenAIUrl(ctx, settings.BaseURL(), settings.ResourceName(), settings.DeploymentID()) | |
if err != nil { | |
return nil, errors.Wrap(err, "join OpenAI API host and path") | |
} | |
fmt.Printf("using the OpenAI URL: %v\n", oaiUrl) | |
req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl, | |
bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
apiKey, err := v.getApiKey(ctx, settings.IsAzure()) | |
if err != nil { | |
return nil, errors.Wrapf(err, "OpenAI API Key") | |
} | |
req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure())) | |
if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" { | |
req.Header.Add("OpenAI-Organization", openAIOrganization) | |
} | |
req.Header.Add("Content-Type", "application/json") | |
res, err := v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
var resBody answersResponse | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if res.StatusCode != 200 || resBody.Error != nil { | |
return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure()) | |
} | |
if len(resBody.Choices) > 0 && resBody.Choices[0].Text != "" { | |
return &ent.AnswerResult{ | |
Text: text, | |
Question: question, | |
Answer: &resBody.Choices[0].Text, | |
}, nil | |
} | |
return &ent.AnswerResult{ | |
Text: text, | |
Question: question, | |
Answer: nil, | |
}, nil | |
} | |
func (v *qna) buildOpenAIUrl(ctx context.Context, baseURL, resourceName, deploymentID string) (string, error) { | |
passedBaseURL := baseURL | |
if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" { | |
passedBaseURL = headerBaseURL | |
} | |
return v.buildUrlFn(passedBaseURL, resourceName, deploymentID) | |
} | |
func (v *qna) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error { | |
endpoint := "OpenAI API" | |
if isAzure { | |
endpoint = "Azure OpenAI API" | |
} | |
if resBodyError != nil { | |
return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) | |
} | |
return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) | |
} | |
func (v *qna) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) { | |
if isAzure { | |
return "api-key", apiKey | |
} | |
return "Authorization", fmt.Sprintf("Bearer %s", apiKey) | |
} | |
func (v *qna) generatePrompt(text string, question string) string { | |
return fmt.Sprintf(`'Please answer the question according to the above context. | |
=== | |
Context: %v | |
=== | |
Q: %v | |
A:`, strings.ReplaceAll(text, "\n", " "), question) | |
} | |
func (v *qna) getApiKey(ctx context.Context, isAzure bool) (string, error) { | |
var apiKey, envVar string | |
if isAzure { | |
apiKey = "X-Azure-Api-Key" | |
envVar = "AZURE_APIKEY" | |
if len(v.azureApiKey) > 0 { | |
return v.azureApiKey, nil | |
} | |
} else { | |
apiKey = "X-Openai-Api-Key" | |
envVar = "OPENAI_APIKEY" | |
if len(v.openAIApiKey) > 0 { | |
return v.openAIApiKey, nil | |
} | |
} | |
return v.getApiKeyFromContext(ctx, apiKey, envVar) | |
} | |
func (v *qna) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { | |
if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { | |
return apiKeyValue, nil | |
} | |
return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) | |
} | |
func (v *qna) getValueFromContext(ctx context.Context, key string) string { | |
if value := ctx.Value(key); value != nil { | |
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { | |
return keyHeader[0] | |
} | |
} | |
// try getting header from GRPC if not successful | |
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { | |
return apiKey[0] | |
} | |
return "" | |
} | |
func (v *qna) getOpenAIOrganization(ctx context.Context) string { | |
if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" { | |
return value | |
} | |
return v.openAIOrganization | |
} | |
// answersInput is the JSON request body that Answer sends to the completions
// endpoint. Each field is serialized under the key in its json tag, and keys
// are emitted in declaration order.
type answersInput struct {
	Prompt      string  `json:"prompt"`
	Model       string  `json:"model"`
	MaxTokens   float64 `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	// Stop sequences; Answer sets this to a single newline.
	Stop             []string `json:"stop"`
	FrequencyPenalty float64  `json:"frequency_penalty"`
	PresencePenalty  float64  `json:"presence_penalty"`
	TopP             float64  `json:"top_p"`
}
// answersResponse is the part of the completions response that Answer reads.
type answersResponse struct {
	// Choices has no json tag. encoding/json's case-insensitive field
	// matching maps the "choices" key onto it.
	Choices []choice
	// Error is non-nil when the response body carries an "error" object.
	Error *openAIApiError `json:"error,omitempty"`
}
// choice is one completion candidate in answersResponse.Choices. Answer only
// reads Text.
type choice struct {
	// FinishReason needs an explicit tag. The API key is "finish_reason", and
	// because of the underscore, encoding/json's case-insensitive matching
	// would never map it onto this field.
	FinishReason string  `json:"finish_reason"`
	Index        float32 `json:"index"`
	// NOTE(review): the API presumably sends logprobs as null unless they are
	// requested. If ever requested, it is an object that cannot decode into a
	// string. The type is kept as string for compatibility with existing
	// callers.
	Logprobs string `json:"logprobs"`
	Text     string `json:"text"`
}
// openAIApiError is the "error" object returned by the OpenAI / Azure OpenAI
// API when a request fails.
type openAIApiError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	// Code may arrive as a JSON number, a JSON string, or null; see
	// UnmarshalJSON.
	Code json.Number `json:"code"`
}

// UnmarshalJSON decodes an API error object and tolerates a non-numeric
// string "code" such as "invalid_api_key".
//
// Decoding such a string directly into json.Number fails with "invalid number
// literal". That failure would make the whole response undecodable and hide
// the API's error message. String codes are now stored verbatim in Code.
// Numbers, null, and other literals behave exactly as plain json.Number
// decoding does.
func (e *openAIApiError) UnmarshalJSON(data []byte) error {
	// alias has the same fields but no methods, so decoding into it does not
	// recurse back into this method.
	type alias openAIApiError
	aux := struct {
		*alias
		// Shadows alias.Code (the shallower field wins in encoding/json), so
		// the raw value can be inspected before it is stored.
		Code json.RawMessage `json:"code"`
	}{alias: (*alias)(e)}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	switch {
	case len(aux.Code) == 0 || string(aux.Code) == "null":
		// Absent or null: leave Code untouched, as json.Number decoding would.
	case aux.Code[0] == '"':
		var s string
		if err := json.Unmarshal(aux.Code, &s); err != nil {
			return err
		}
		e.Code = json.Number(s)
	default:
		// Numbers decode normally; other literals (e.g. true) still error.
		var n json.Number
		if err := json.Unmarshal(aux.Code, &n); err != nil {
			return err
		}
		e.Code = n
	}
	return nil
}