//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package clients

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/modules/text2vec-openai/ent"
	"github.com/weaviate/weaviate/usecases/modulecomponents"
)
// embeddingsRequest is the JSON body posted to the embeddings endpoint.
type embeddingsRequest struct {
	// Input holds the texts to embed.
	Input []string `json:"input"`
	// Model names the embedding model; it is left out of the JSON when empty.
	Model string `json:"model,omitempty"`
}
type embedding struct { | |
Object string `json:"object"` | |
Data []embeddingData `json:"data,omitempty"` | |
Error *openAIApiError `json:"error,omitempty"` | |
} | |
// embeddingData is a single entry of the "data" array in an embeddings
// response.
type embeddingData struct {
	// Object is the entry's "object" field.
	Object string `json:"object"`
	// Index is the entry's "index" field.
	Index int `json:"index"`
	// Embedding is the vector returned for this entry.
	Embedding []float32 `json:"embedding"`
}
// openAIApiError is the "error" object of an API response.
type openAIApiError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	// Code keeps the raw error code. It is filled by UnmarshalJSON, which
	// accepts both JSON numbers and arbitrary JSON strings.
	Code json.Number `json:"code"`
}

// UnmarshalJSON decodes an error object whose "code" is either a JSON number
// (e.g. 429) or a JSON string (e.g. "invalid_api_key").
//
// The default decoding of json.Number rejects strings that are not valid number
// literals, which would make the whole response fail to decode and hide the
// actual error message. A missing or null code leaves Code empty.
func (e *openAIApiError) UnmarshalJSON(data []byte) error {
	// Decode into an anonymous struct so this method is not invoked recursively.
	var raw struct {
		Message string          `json:"message"`
		Type    string          `json:"type"`
		Param   string          `json:"param"`
		Code    json.RawMessage `json:"code"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	var code json.Number
	if len(raw.Code) > 0 {
		var text string
		if err := json.Unmarshal(raw.Code, &text); err == nil {
			// A JSON string (or null, which leaves text empty).
			code = json.Number(text)
		} else if err := json.Unmarshal(raw.Code, &code); err != nil {
			// Neither a string nor a number.
			return err
		}
	}

	*e = openAIApiError{
		Message: raw.Message,
		Type:    raw.Type,
		Param:   raw.Param,
		Code:    code,
	}
	return nil
}
// buildUrl returns the embeddings endpoint URL.
//
// When isAzure is false the result is baseURL joined with "/v1/embeddings".
// When isAzure is true the host is baseURL, unless baseURL is empty or the
// OpenAI default "https://api.openai.com", in which case it is derived from
// resourceName. Trailing slashes on the Azure host are dropped so the result
// never contains "//" before the path.
func buildUrl(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
	if !isAzure {
		return url.JoinPath(baseURL, "/v1/embeddings")
	}

	host := baseURL
	if host == "" || host == "https://api.openai.com" {
		// Fall back to old assumption
		host = "https://" + resourceName + ".openai.azure.com"
	}
	host = strings.TrimRight(host, "/")

	path := "openai/deployments/" + deploymentID + "/embeddings"
	queryParam := "api-version=2022-12-01"
	return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil
}
type vectorizer struct { | |
openAIApiKey string | |
openAIOrganization string | |
azureApiKey string | |
httpClient *http.Client | |
buildUrlFn func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) | |
logger logrus.FieldLogger | |
} | |
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { | |
return &vectorizer{ | |
openAIApiKey: openAIApiKey, | |
openAIOrganization: openAIOrganization, | |
azureApiKey: azureApiKey, | |
httpClient: &http.Client{ | |
Timeout: timeout, | |
}, | |
buildUrlFn: buildUrl, | |
logger: logger, | |
} | |
} | |
func (v *vectorizer) Vectorize(ctx context.Context, input string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, []string{input}, v.getModelString(config.Type, config.Model, "document", config.ModelVersion), config) | |
} | |
func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, input, v.getModelString(config.Type, config.Model, "query", config.ModelVersion), config) | |
} | |
func (v *vectorizer) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { | |
body, err := json.Marshal(v.getEmbeddingsRequest(input, model, config.IsAzure)) | |
if err != nil { | |
return nil, errors.Wrap(err, "marshal body") | |
} | |
endpoint, err := v.buildURL(ctx, config) | |
if err != nil { | |
return nil, errors.Wrap(err, "join OpenAI API host and path") | |
} | |
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, | |
bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
apiKey, err := v.getApiKey(ctx, config.IsAzure) | |
if err != nil { | |
return nil, errors.Wrap(err, "API Key") | |
} | |
req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, config.IsAzure)) | |
if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" { | |
req.Header.Add("OpenAI-Organization", openAIOrganization) | |
} | |
req.Header.Add("Content-Type", "application/json") | |
res, err := v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
var resBody embedding | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if res.StatusCode != 200 || resBody.Error != nil { | |
return nil, v.getError(res.StatusCode, resBody.Error, config.IsAzure) | |
} | |
texts := make([]string, len(resBody.Data)) | |
embeddings := make([][]float32, len(resBody.Data)) | |
for i := range resBody.Data { | |
texts[i] = resBody.Data[i].Object | |
embeddings[i] = resBody.Data[i].Embedding | |
} | |
return &ent.VectorizationResult{ | |
Text: texts, | |
Dimensions: len(resBody.Data[0].Embedding), | |
Vector: embeddings, | |
}, nil | |
} | |
func (v *vectorizer) buildURL(ctx context.Context, config ent.VectorizationConfig) (string, error) { | |
baseURL, resourceName, deploymentID, isAzure := config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure | |
if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" { | |
baseURL = headerBaseURL | |
} | |
return v.buildUrlFn(baseURL, resourceName, deploymentID, isAzure) | |
} | |
func (v *vectorizer) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error { | |
endpoint := "OpenAI API" | |
if isAzure { | |
endpoint = "Azure OpenAI API" | |
} | |
if resBodyError != nil { | |
return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) | |
} | |
return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) | |
} | |
func (v *vectorizer) getEmbeddingsRequest(input []string, model string, isAzure bool) embeddingsRequest { | |
if isAzure { | |
return embeddingsRequest{Input: input} | |
} | |
return embeddingsRequest{Input: input, Model: model} | |
} | |
func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) { | |
if isAzure { | |
return "api-key", apiKey | |
} | |
return "Authorization", fmt.Sprintf("Bearer %s", apiKey) | |
} | |
func (v *vectorizer) getOpenAIOrganization(ctx context.Context) string { | |
if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" { | |
return value | |
} | |
return v.openAIOrganization | |
} | |
func (v *vectorizer) getApiKey(ctx context.Context, isAzure bool) (string, error) { | |
var apiKey, envVar string | |
if isAzure { | |
apiKey = "X-Azure-Api-Key" | |
envVar = "AZURE_APIKEY" | |
if len(v.azureApiKey) > 0 { | |
return v.azureApiKey, nil | |
} | |
} else { | |
apiKey = "X-Openai-Api-Key" | |
envVar = "OPENAI_APIKEY" | |
if len(v.openAIApiKey) > 0 { | |
return v.openAIApiKey, nil | |
} | |
} | |
return v.getApiKeyFromContext(ctx, apiKey, envVar) | |
} | |
func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { | |
if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { | |
return apiKeyValue, nil | |
} | |
return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) | |
} | |
func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string { | |
if value := ctx.Value(key); value != nil { | |
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { | |
return keyHeader[0] | |
} | |
} | |
// try getting header from GRPC if not successful | |
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { | |
return apiKey[0] | |
} | |
return "" | |
} | |
func (v *vectorizer) getModelString(docType, model, action, version string) string { | |
if version == "002" { | |
return v.getModel002String(model) | |
} | |
return v.getModel001String(docType, model, action) | |
} | |
func (v *vectorizer) getModel001String(docType, model, action string) string { | |
modelBaseString := "%s-search-%s-%s-001" | |
if action == "document" { | |
if docType == "code" { | |
return fmt.Sprintf(modelBaseString, docType, model, "code") | |
} | |
return fmt.Sprintf(modelBaseString, docType, model, "doc") | |
} else { | |
if docType == "code" { | |
return fmt.Sprintf(modelBaseString, docType, model, "text") | |
} | |
return fmt.Sprintf(modelBaseString, docType, model, "query") | |
} | |
} | |
func (v *vectorizer) getModel002String(model string) string { | |
modelBaseString := "text-embedding-%s-002" | |
return fmt.Sprintf(modelBaseString, model) | |
} | |