KevinStephenson
Adding in weaviate code
b110593
raw
history blame
8.49 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/modules/text2vec-openai/ent"
)
// embeddingsRequest is the JSON body that vectorize marshals and POSTs to the
// embeddings endpoint.
type embeddingsRequest struct {
// Input holds the texts to embed.
Input []string `json:"input"`
// Model is the model name; it is dropped from the JSON when empty
// (getEmbeddingsRequest leaves it empty for Azure requests).
Model string `json:"model,omitempty"`
}
// embedding is the shape into which vectorize decodes the response body of
// the embeddings endpoint.
type embedding struct {
Object string `json:"object"`
// Data holds one entry per returned embedding.
Data []embeddingData `json:"data,omitempty"`
// Error is non-nil when the response body contains an "error" object;
// vectorize treats that as a failure regardless of the HTTP status.
Error *openAIApiError `json:"error,omitempty"`
}
// embeddingData is a single element of the "data" array in an embeddings
// response.
type embeddingData struct {
Object string `json:"object"`
Index int `json:"index"`
// Embedding is the vector returned for this element.
Embedding []float32 `json:"embedding"`
}
// openAIApiError mirrors the "error" object of an embeddings response. Only
// Message is used when building the error returned by getError.
type openAIApiError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
// Code is decoded as a json.Number.
Code json.Number `json:"code"`
}
// buildUrl assembles the embeddings endpoint URL.
//
// For non-Azure requests it joins baseURL with "/v1/embeddings". For Azure
// requests it targets the deployment-scoped embeddings path with a fixed
// api-version query parameter; when baseURL is empty or still the public
// OpenAI host, the Azure host is derived from resourceName instead.
func buildUrl(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
	if !isAzure {
		return url.JoinPath(baseURL, "/v1/embeddings")
	}

	azureHost := baseURL
	switch azureHost {
	case "", "https://api.openai.com":
		// No Azure-specific base URL was configured: fall back to the
		// resource-name based host.
		azureHost = "https://" + resourceName + ".openai.azure.com"
	}

	endpoint := fmt.Sprintf("%s/openai/deployments/%s/embeddings?api-version=2022-12-01", azureHost, deploymentID)
	return endpoint, nil
}
// vectorizer is the HTTP client used to request embeddings from the OpenAI
// or Azure OpenAI embeddings endpoint.
type vectorizer struct {
// openAIApiKey, when non-empty, is used for non-Azure requests before any
// per-request header value is consulted.
openAIApiKey string
// openAIOrganization is the fallback for the "OpenAI-Organization" header
// when the request context carries no override.
openAIOrganization string
// azureApiKey, when non-empty, is used for Azure requests before any
// per-request header value is consulted.
azureApiKey string
httpClient *http.Client
// buildUrlFn builds the endpoint URL; New sets it to buildUrl.
buildUrlFn func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error)
logger logrus.FieldLogger
}
// New returns a vectorizer configured with the given credentials and logger.
// Its HTTP client uses timeout as the overall per-request timeout, and URLs
// are built with buildUrl.
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
	client := &http.Client{Timeout: timeout}

	v := &vectorizer{
		httpClient: client,
		buildUrlFn: buildUrl,
		logger:     logger,
	}
	v.openAIApiKey = openAIApiKey
	v.openAIOrganization = openAIOrganization
	v.azureApiKey = azureApiKey
	return v
}
// Vectorize requests an embedding for a single input string, using the model
// name derived from config for the "document" action.
func (v *vectorizer) Vectorize(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	documentModel := v.getModelString(config.Type, config.Model, "document", config.ModelVersion)
	texts := []string{input}
	return v.vectorize(ctx, texts, documentModel, config)
}
// VectorizeQuery requests embeddings for the given inputs, using the model
// name derived from config for the "query" action.
func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	queryModel := v.getModelString(config.Type, config.Model, "query", config.ModelVersion)
	return v.vectorize(ctx, input, queryModel, config)
}
// vectorize POSTs input to the embeddings endpoint selected by config and
// converts the decoded response into a VectorizationResult.
//
// The request carries the API key header appropriate for the target (see
// getApiKeyHeaderAndValue) and, when one is available, an
// "OpenAI-Organization" header. An error is returned when the request cannot
// be built or sent, when the response body cannot be read or decoded, when
// the status is not 200 or the body contains an error object, or when the
// response contains no embeddings at all.
//
// In the result, Text holds the "object" field of each returned data element,
// Vector holds the embeddings in response order, and Dimensions is the length
// of the first embedding.
func (v *vectorizer) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
	body, err := json.Marshal(v.getEmbeddingsRequest(input, model, config.IsAzure))
	if err != nil {
		return nil, errors.Wrap(err, "marshal body")
	}

	endpoint, err := v.buildURL(ctx, config)
	if err != nil {
		return nil, errors.Wrap(err, "join OpenAI API host and path")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint,
		bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}

	apiKey, err := v.getApiKey(ctx, config.IsAzure)
	if err != nil {
		return nil, errors.Wrap(err, "API Key")
	}
	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, config.IsAzure))
	if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
		req.Header.Add("OpenAI-Organization", openAIOrganization)
	}
	req.Header.Add("Content-Type", "application/json")

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	var resBody embedding
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	if res.StatusCode != http.StatusOK || resBody.Error != nil {
		return nil, v.getError(res.StatusCode, resBody.Error, config.IsAzure)
	}

	// A 200 response with an empty or missing "data" array would otherwise
	// panic on the Data[0] access below; report it as an error instead.
	if len(resBody.Data) == 0 {
		return nil, errors.New("empty embeddings response")
	}

	texts := make([]string, len(resBody.Data))
	embeddings := make([][]float32, len(resBody.Data))
	for i := range resBody.Data {
		texts[i] = resBody.Data[i].Object
		embeddings[i] = resBody.Data[i].Embedding
	}

	return &ent.VectorizationResult{
		Text:       texts,
		Dimensions: len(resBody.Data[0].Embedding),
		Vector:     embeddings,
	}, nil
}
// buildURL resolves the endpoint URL for a request. The base URL comes from
// config unless the request context supplies a non-empty "X-Openai-Baseurl"
// value, which takes precedence. The actual URL construction is delegated to
// v.buildUrlFn.
func (v *vectorizer) buildURL(ctx context.Context, config ent.VectorizationConfig) (string, error) {
	effectiveBaseURL := config.BaseURL
	if override := v.getValueFromContext(ctx, "X-Openai-Baseurl"); override != "" {
		effectiveBaseURL = override
	}
	return v.buildUrlFn(effectiveBaseURL, config.ResourceName, config.DeploymentID, config.IsAzure)
}
// getError builds the error returned for a failed embeddings call. The
// message names the API that was contacted (Azure or plain OpenAI) and the
// HTTP status, and appends the API's own error message when the response
// body carried one.
func (v *vectorizer) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
	apiName := "OpenAI API"
	if isAzure {
		apiName = "Azure OpenAI API"
	}

	if resBodyError == nil {
		return fmt.Errorf("connection to: %s failed with status: %d", apiName, statusCode)
	}
	return fmt.Errorf("connection to: %s failed with status: %d error: %v", apiName, statusCode, resBodyError.Message)
}
// getEmbeddingsRequest builds the request payload. The model name is only
// included for non-Azure requests; Azure requests carry the input alone.
func (v *vectorizer) getEmbeddingsRequest(input []string, model string, isAzure bool) embeddingsRequest {
	payload := embeddingsRequest{Input: input}
	if !isAzure {
		payload.Model = model
	}
	return payload
}
// getApiKeyHeaderAndValue returns the header name and value used to
// authenticate a request: a bearer "Authorization" header for non-Azure
// requests, or the raw key in an "api-key" header for Azure requests.
func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
	if !isAzure {
		bearer := fmt.Sprintf("Bearer %s", apiKey)
		return "Authorization", bearer
	}
	return "api-key", apiKey
}
// getOpenAIOrganization returns the organization to send with a request: the
// "X-Openai-Organization" value from the request context when non-empty,
// otherwise the organization the vectorizer was constructed with.
func (v *vectorizer) getOpenAIOrganization(ctx context.Context) string {
	fromRequest := v.getValueFromContext(ctx, "X-Openai-Organization")
	if fromRequest == "" {
		return v.openAIOrganization
	}
	return fromRequest
}
// getApiKey returns the API key for the targeted service. The key the
// vectorizer was constructed with wins when it is non-empty; otherwise the
// key is looked up in the request context under the service-specific header
// name ("X-Azure-Api-Key" or "X-Openai-Api-Key").
func (v *vectorizer) getApiKey(ctx context.Context, isAzure bool) (string, error) {
	configuredKey, headerName, envVarName := v.openAIApiKey, "X-Openai-Api-Key", "OPENAI_APIKEY"
	if isAzure {
		configuredKey, headerName, envVarName = v.azureApiKey, "X-Azure-Api-Key", "AZURE_APIKEY"
	}

	if configuredKey != "" {
		return configuredKey, nil
	}
	return v.getApiKeyFromContext(ctx, headerName, envVarName)
}
// getApiKeyFromContext looks up the API key in the request context under the
// header name apiKey. When no non-empty value is found it returns an error
// that mentions both the header name and envVar.
func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
	found := v.getValueFromContext(ctx, apiKey)
	if found == "" {
		return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
	}
	return found, nil
}
// getValueFromContext returns the first non-empty value stored under key.
// It first checks ctx.Value(key) for a []string whose first element is
// non-empty, then falls back to modulecomponents.GetValueFromGRPC (presumably
// gRPC request metadata — see that helper). It returns "" when neither
// source yields a value.
func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string {
	// A failed type assertion (including a nil context value) yields ok == false.
	if headerValues, ok := ctx.Value(key).([]string); ok && len(headerValues) > 0 && headerValues[0] != "" {
		return headerValues[0]
	}

	grpcValues := modulecomponents.GetValueFromGRPC(ctx, key)
	if len(grpcValues) > 0 && grpcValues[0] != "" {
		return grpcValues[0]
	}
	return ""
}
// getModelString returns the model name to request. Version "002" uses the
// 002 naming scheme; every other version falls back to the 001 scheme, which
// also depends on docType and action.
func (v *vectorizer) getModelString(docType, model, action, version string) string {
	switch version {
	case "002":
		return v.getModel002String(model)
	default:
		return v.getModel001String(docType, model, action)
	}
}
// getModel001String builds a model name of the form
// "<docType>-search-<model>-<suffix>-001". The suffix depends on the action
// and document type:
//
//	action "document": "code" for docType "code", otherwise "doc"
//	any other action:  "text" for docType "code", otherwise "query"
func (v *vectorizer) getModel001String(docType, model, action string) string {
	isDocument, isCode := action == "document", docType == "code"

	var suffix string
	switch {
	case isDocument && isCode:
		suffix = "code"
	case isDocument:
		suffix = "doc"
	case isCode:
		suffix = "text"
	default:
		suffix = "query"
	}
	return fmt.Sprintf("%s-search-%s-%s-001", docType, model, suffix)
}
// getModel002String builds a model name of the form
// "text-embedding-<model>-002".
func (v *vectorizer) getModel002String(model string) string {
	return fmt.Sprintf("text-embedding-%s-002", model)
}