KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.66 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/modules/text2vec-jinaai/ent"
)
// embeddingsRequest is the JSON body sent to the embeddings endpoint.
type embeddingsRequest struct {
	// Input holds the texts to embed; serialized under the "input" key.
	Input []string `json:"input"`
	// Model names the embedding model; left out of the payload when empty.
	Model string `json:"model,omitempty"`
}
// embedding is the top-level JSON document decoded from the embeddings
// endpoint's response body.
type embedding struct {
	// Object is the value of the response's "object" key.
	Object string `json:"object"`
	// Data carries one entry per returned embedding.
	Data []embeddingData `json:"data,omitempty"`
	// Error is non-nil when the response contains an "error" object.
	Error *jinaAIApiError `json:"error,omitempty"`
}
// embeddingData is a single element of the response's "data" array.
type embeddingData struct {
	// Object is the value of the element's "object" key.
	Object string `json:"object"`
	// Index is the value of the element's "index" key.
	Index int `json:"index"`
	// Embedding is the vector returned for this element.
	Embedding []float32 `json:"embedding"`
}
// jinaAIApiError mirrors the "error" object of an API response.
type jinaAIApiError struct {
	// Message is the error text from the response's "message" key.
	Message string `json:"message"`
	// Type is the value of the "type" key.
	Type string `json:"type"`
	// Param is the value of the "param" key.
	Param string `json:"param"`
	// Code is the value of the "code" key.
	Code string `json:"code"`
}
// buildUrl joins the configured base URL with the embeddings path
// "/v1/embeddings" and returns the resulting endpoint, or the error reported
// by url.JoinPath.
func buildUrl(config ent.VectorizationConfig) (string, error) {
	const embeddingsPath = "/v1/embeddings"
	return url.JoinPath(config.BaseURL, embeddingsPath)
}
// vectorizer is the HTTP client used to request embeddings from the JinaAI API.
type vectorizer struct {
	// jinaAIApiKey, when non-empty, is used instead of a key taken from the
	// request context.
	jinaAIApiKey string
	// httpClient performs the outgoing POST requests.
	httpClient *http.Client
	// buildUrlFn derives the endpoint URL from the vectorization config.
	buildUrlFn func(config ent.VectorizationConfig) (string, error)
	// logger is the logger handed to New.
	logger logrus.FieldLogger
}
// New creates a vectorizer that authenticates with jinaAIApiKey (which may be
// empty), applies timeout to its HTTP client, and resolves the endpoint with
// buildUrl.
func New(jinaAIApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
	client := &http.Client{Timeout: timeout}

	v := new(vectorizer)
	v.jinaAIApiKey = jinaAIApiKey
	v.httpClient = client
	v.buildUrlFn = buildUrl
	v.logger = logger
	return v
}
// Vectorize requests an embedding for a single input text, using the model
// named in config.
func (v *vectorizer) Vectorize(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	texts := []string{input}
	return v.vectorize(ctx, texts, config.Model, config)
}
// VectorizeQuery requests embeddings for all given input texts in one call,
// using the model named in config.
func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	model := config.Model
	return v.vectorize(ctx, input, model, config)
}
// vectorize sends input to the embeddings endpoint and converts the response
// into a VectorizationResult.
//
// The request body is the JSON encoding of the input texts and model, the
// endpoint comes from v.buildUrlFn(config), and the API key is sent as a
// bearer token. An error is returned when any step of building, sending or
// decoding the request fails, when the status code is not 200, when the
// response carries an "error" object, or when the response contains no
// embeddings at all.
func (v *vectorizer) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
	body, err := json.Marshal(v.getEmbeddingsRequest(input, model))
	if err != nil {
		return nil, errors.Wrap(err, "marshal body")
	}

	endpoint, err := v.buildUrlFn(config)
	if err != nil {
		return nil, errors.Wrap(err, "join jinaAI API host and path")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint,
		bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}

	apiKey, err := v.getApiKey(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "API Key")
	}
	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey))
	req.Header.Add("Content-Type", "application/json")

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	var resBody embedding
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}

	if res.StatusCode != http.StatusOK || resBody.Error != nil {
		return nil, v.getError(res.StatusCode, resBody.Error)
	}

	// A successful response without any data entries would make the
	// Data[0] access below panic; report it as an error instead.
	if len(resBody.Data) == 0 {
		return nil, errors.New("empty embeddings response")
	}

	texts := make([]string, len(resBody.Data))
	embeddings := make([][]float32, len(resBody.Data))
	for i := range resBody.Data {
		texts[i] = resBody.Data[i].Object
		embeddings[i] = resBody.Data[i].Embedding
	}

	return &ent.VectorizationResult{
		Text: texts,
		// Dimensions is taken from the first returned embedding.
		Dimensions: len(resBody.Data[0].Embedding),
		Vector:     embeddings,
	}, nil
}
// getError builds the error returned for a failed API call. The message
// always contains the HTTP status code; when the response carried an error
// object, its message is appended.
func (v *vectorizer) getError(statusCode int, resBodyError *jinaAIApiError) error {
	const endpoint = "JinaAI API"
	if resBodyError == nil {
		return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
	}
	return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
}
// getEmbeddingsRequest wraps the input texts and model name in the request
// payload struct.
func (v *vectorizer) getEmbeddingsRequest(input []string, model string) embeddingsRequest {
	var req embeddingsRequest
	req.Input = input
	req.Model = model
	return req
}
// getApiKeyHeaderAndValue returns the header name and value that carry the
// API key: "Authorization" and the key prefixed with "Bearer ".
func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string) (string, string) {
	const header = "Authorization"
	value := fmt.Sprintf("Bearer %s", apiKey)
	return header, value
}
// getApiKey returns the API key to use for a request. A key configured on
// the vectorizer wins; otherwise the key is looked up in ctx under the
// "X-Jinaai-Api-Key" header name.
func (v *vectorizer) getApiKey(ctx context.Context) (string, error) {
	if v.jinaAIApiKey != "" {
		return v.jinaAIApiKey, nil
	}

	const (
		headerName = "X-Jinaai-Api-Key"
		// envVarName is only passed along for the lookup's error message.
		envVarName = "JINAAI_APIKEY"
	)
	return v.getApiKeyFromContext(ctx, headerName, envVarName)
}
// getApiKeyFromContext looks up the value stored in ctx under apiKey (a
// header name) and returns it. When no non-empty value is found it returns
// an error naming both the header and envVar.
func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
	apiKeyValue := v.getValueFromContext(ctx, apiKey)
	if apiKeyValue == "" {
		return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
	}
	return apiKeyValue, nil
}
// getValueFromContext returns the first non-empty value stored under key.
// It first checks for a []string context value and then falls back to
// modulecomponents.GetValueFromGRPC; an empty string means neither source
// had a usable value.
func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string {
	// A missing value or one of another type fails the assertion and falls
	// through to the gRPC lookup.
	if headerValues, ok := ctx.Value(key).([]string); ok && len(headerValues) > 0 && headerValues[0] != "" {
		return headerValues[0]
	}

	// try getting header from GRPC if not successful
	grpcValues := modulecomponents.GetValueFromGRPC(ctx, key)
	if len(grpcValues) > 0 && grpcValues[0] != "" {
		return grpcValues[0]
	}
	return ""
}