KevinStephenson
Adding in weaviate code
b110593
raw
history blame
6.12 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/weaviate/weaviate/usecases/modulecomponents"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
)
// DefaultOrigin is the scheme and host of the public HuggingFace inference
// API, used when no explicit endpoint URL is configured.
const DefaultOrigin = "https://api-inference.huggingface.co"

// DefaultPath is the path segment of the feature-extraction pipeline; the
// model name is appended after it to form the request URL.
const DefaultPath = "pipeline/feature-extraction"
// Wire types exchanged with the HuggingFace inference API.
type (
	// embeddingsRequest is the JSON body POSTed to the feature-extraction
	// endpoint.
	embeddingsRequest struct {
		Inputs  []string `json:"inputs"`
		Options *options `json:"options,omitempty"`
	}

	// options carries the optional inference flags; false values are
	// omitted from the JSON entirely.
	options struct {
		WaitForModel bool `json:"wait_for_model,omitempty"`
		UseGPU       bool `json:"use_gpu,omitempty"`
		UseCache     bool `json:"use_cache,omitempty"`
	}

	// embedding is a response shaped as a 2-D array: one row per input.
	embedding [][]float32

	// embeddingBert is a response shaped as a 4-D nested array.
	embeddingBert [][][][]float32

	// embeddingObject is a response that wraps the 2-D array in an
	// "embeddings" field.
	embeddingObject struct {
		Embeddings embedding `json:"embeddings"`
	}

	// huggingFaceApiError is the JSON body expected on error responses.
	huggingFaceApiError struct {
		Error         string   `json:"error"`
		EstimatedTime *float32 `json:"estimated_time,omitempty"`
		Warnings      []string `json:"warnings"`
	}
)
// vectorizer sends text to a HuggingFace feature-extraction endpoint and
// decodes the returned embedding. Instances are built by New.
type vectorizer struct {
	// apiKey, when non-empty, is used for every request and takes precedence
	// over any key found in the request context (see getApiKey).
	apiKey string
	// httpClient performs the POST requests; New sets its Timeout.
	httpClient *http.Client
	// bertEmbeddingsDecoder is used by decodeVector to compute a single
	// vector when the response is a 4-D nested array.
	bertEmbeddingsDecoder *bertEmbeddingsDecoder
	// logger is the logger passed to New.
	logger logrus.FieldLogger
}
// New builds a HuggingFace vectorizer.
//
// apiKey, if non-empty, is sent with every request; otherwise a key is looked
// up per request from the context. timeout is applied to the underlying
// http.Client, so it bounds each whole request/response exchange.
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
	client := &http.Client{Timeout: timeout}
	decoder := newBertEmbeddingsDecoder()

	return &vectorizer{
		apiKey:                apiKey,
		httpClient:            client,
		bertEmbeddingsDecoder: decoder,
		logger:                logger,
	}
}
// Vectorize returns the embedding of input, using the endpoint and inference
// options derived from config.
func (v *vectorizer) Vectorize(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	endpoint := v.getURL(config)
	opts := v.getOptions(config)
	return v.vectorize(ctx, endpoint, input, opts)
}
// VectorizeQuery returns the embedding of a search query. Queries and
// documents are embedded the same way, so this simply defers to Vectorize.
func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.Vectorize(ctx, input, config)
}
// vectorize POSTs a single input to the feature-extraction endpoint at url and
// decodes the embedding from the response.
//
// The request body is {"inputs": [input], "options": {...}}. When getApiKey
// yields a non-empty key it is sent as a Bearer token. Responses with a status
// code of 400 or above are converted to errors by checkResponse; otherwise the
// body is decoded by decodeVector. On success the result carries the original
// text, the vector, and its length as Dimensions.
func (v *vectorizer) vectorize(ctx context.Context, url string,
	input string, options options,
) (*ent.VectorizationResult, error) {
	body, err := json.Marshal(embeddingsRequest{
		Inputs:  []string{input},
		Options: &options,
	})
	if err != nil {
		// Wrap rather than Wrapf: the message has no format verbs.
		return nil, errors.Wrap(err, "marshal body")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "create POST request")
	}
	if apiKey := v.getApiKey(ctx); apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := v.httpClient.Do(req)
	if err != nil {
		return nil, errors.Wrap(err, "send POST request")
	}
	defer res.Body.Close()

	// Read the whole body up front: it is needed both for building error
	// messages and for decoding the vector.
	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, errors.Wrap(err, "read response body")
	}

	if err := checkResponse(res, bodyBytes); err != nil {
		return nil, err
	}

	vector, err := v.decodeVector(bodyBytes)
	if err != nil {
		return nil, errors.Wrap(err, "cannot decode vector")
	}

	return &ent.VectorizationResult{
		Text:       input,
		Dimensions: len(vector),
		Vector:     vector,
	}, nil
}
// checkResponse converts an error response from the HuggingFace inference API
// into a Go error. Responses with a status code below 400 yield nil.
//
// For error responses the body is expected to be JSON of the form
// {"error": "...", "estimated_time": <number>, "warnings": [...]}. When
// "error" is non-empty, it is appended to the message, followed by the
// estimated time and warnings when present. A 500 status is additionally
// prefixed with "connection to HuggingFace". If the body cannot be decoded into
// that shape, the raw body is returned together with the status code.
func checkResponse(res *http.Response, bodyBytes []byte) error {
	if res.StatusCode < 400 {
		return nil
	}

	status := fmt.Sprintf("failed with status: %d", res.StatusCode)

	var resBody huggingFaceApiError
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
		// The body is not the expected JSON (e.g. an HTML error page, or JSON
		// whose fields do not match huggingFaceApiError). Keep the raw body for
		// diagnosis, and keep the status code, which was previously lost here.
		return fmt.Errorf("unmarshal error response body: %v (%s)", string(bodyBytes), status)
	}

	message := status
	if resBody.Error != "" {
		message = fmt.Sprintf("%s error: %v", message, resBody.Error)
		if resBody.EstimatedTime != nil {
			message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime)
		}
		if len(resBody.Warnings) > 0 {
			message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings)
		}
	}
	if res.StatusCode == http.StatusInternalServerError {
		message = fmt.Sprintf("connection to HuggingFace %v", message)
	}
	return errors.New(message)
}
// decodeVector extracts a single vector from a successful response body. It
// tries three response shapes in order and uses the first one that parses:
//
//  1. a 2-D array ([][]float32) with exactly one row;
//  2. an object {"embeddings": [[...]]} with exactly one row;
//  3. a 4-D nested array whose first two levels each hold exactly one
//     element. The remaining [][]float32 is passed to
//     bertEmbeddingsDecoder.calculateVector.
//
// A shape that parses but has the wrong number of elements yields
// "unprocessable response body". If none of the shapes parses, the
// error from the last (4-D) attempt is returned, wrapped.
func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) {
	var rows embedding
	if json.Unmarshal(bodyBytes, &rows) == nil {
		if len(rows) != 1 {
			return nil, errors.New("unprocessable response body")
		}
		return rows[0], nil
	}

	var wrapped embeddingObject
	if json.Unmarshal(bodyBytes, &wrapped) == nil {
		if len(wrapped.Embeddings) != 1 {
			return nil, errors.New("unprocessable response body")
		}
		return wrapped.Embeddings[0], nil
	}

	var nested embeddingBert
	if err := json.Unmarshal(bodyBytes, &nested); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}
	if len(nested) != 1 || len(nested[0]) != 1 {
		return nil, errors.New("unprocessable response body")
	}
	return v.bertEmbeddingsDecoder.calculateVector(nested[0][0])
}
func (v *vectorizer) getApiKey(ctx context.Context) string {
if len(v.apiKey) > 0 {
return v.apiKey
}
key := "X-Huggingface-Api-Key"
apiKey := ctx.Value(key)
// try getting header from GRPC if not successful
if apiKey == nil {
apiKey = modulecomponents.GetValueFromGRPC(ctx, key)
}
if apiKeyHeader, ok := apiKey.([]string); ok &&
len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 {
return apiKeyHeader[0]
}
return ""
}
// getOptions copies the inference flags from config into the request options.
func (v *vectorizer) getOptions(config ent.VectorizationConfig) options {
	var opts options
	opts.WaitForModel = config.WaitForModel
	opts.UseGPU = config.UseGPU
	opts.UseCache = config.UseCache
	return opts
}
// getURL returns the endpoint to call. It is config.EndpointURL when that is
// set, and otherwise the public inference API URL for config.Model, in the
// form DefaultOrigin/DefaultPath/<model>.
func (v *vectorizer) getURL(config ent.VectorizationConfig) string {
	endpoint := config.EndpointURL
	if endpoint == "" {
		endpoint = DefaultOrigin + "/" + DefaultPath + "/" + config.Model
	}
	return endpoint
}