//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package clients

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
	"github.com/weaviate/weaviate/usecases/modulecomponents"
)

// Defaults used to build the request URL when the vectorization config does
// not name an explicit endpoint.
const (
	// DefaultOrigin is the scheme and host of the default inference endpoint.
	DefaultOrigin = "https://api-inference.huggingface.co"
	// DefaultPath is the path segment placed between DefaultOrigin and the
	// model name.
	DefaultPath = "pipeline/feature-extraction"
)

// embeddingsRequest is the JSON body sent to the inference endpoint.
type embeddingsRequest struct {
	// Inputs holds the texts to embed; it is always serialized.
	Inputs []string `json:"inputs"`
	// Options is left out of the JSON when the pointer is nil.
	Options *options `json:"options,omitempty"`
}

// options is the "options" object of an embeddingsRequest. Every field is
// omitted from the JSON when false, so only enabled flags reach the server.
type options struct {
	WaitForModel bool `json:"wait_for_model,omitempty"`
	UseGPU       bool `json:"use_gpu,omitempty"`
	UseCache     bool `json:"use_cache,omitempty"`
}

// embedding is the plain response shape: a list of vectors. decodeVector
// accepts it only when it holds exactly one vector.
type embedding [][]float32

// embeddingBert is the four-level nested response shape that decodeVector
// falls back to last; its innermost matrix is handed to the
// bertEmbeddingsDecoder.
// NOTE(review): the name suggests BERT-style token-level output — confirm.
type embeddingBert [][][][]float32

// embeddingObject is the response shape that wraps the vectors in an
// "embeddings" property.
type embeddingObject struct {
	Embeddings embedding `json:"embeddings"`
}

// huggingFaceApiError is the JSON error payload that checkResponse decodes
// from responses with a status code of 400 or above.
type huggingFaceApiError struct {
	// Error is the error text; checkResponse only reports the remaining
	// fields when it is non-empty.
	Error string `json:"error"`
	// EstimatedTime is optional and stays nil when the payload omits it.
	EstimatedTime *float32 `json:"estimated_time,omitempty"`
	// Warnings lists any warnings that accompany the error.
	Warnings []string `json:"warnings"`
}

type vectorizer struct { | |
apiKey string | |
httpClient *http.Client | |
bertEmbeddingsDecoder *bertEmbeddingsDecoder | |
logger logrus.FieldLogger | |
} | |
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { | |
return &vectorizer{ | |
apiKey: apiKey, | |
httpClient: &http.Client{ | |
Timeout: timeout, | |
}, | |
bertEmbeddingsDecoder: newBertEmbeddingsDecoder(), | |
logger: logger, | |
} | |
} | |
func (v *vectorizer) Vectorize(ctx context.Context, input string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) | |
} | |
func (v *vectorizer) VectorizeQuery(ctx context.Context, input string, | |
config ent.VectorizationConfig, | |
) (*ent.VectorizationResult, error) { | |
return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) | |
} | |
func (v *vectorizer) vectorize(ctx context.Context, url string, | |
input string, options options, | |
) (*ent.VectorizationResult, error) { | |
body, err := json.Marshal(embeddingsRequest{ | |
Inputs: []string{input}, | |
Options: &options, | |
}) | |
if err != nil { | |
return nil, errors.Wrapf(err, "marshal body") | |
} | |
req, err := http.NewRequestWithContext(ctx, "POST", url, | |
bytes.NewReader(body)) | |
if err != nil { | |
return nil, errors.Wrap(err, "create POST request") | |
} | |
if apiKey := v.getApiKey(ctx); apiKey != "" { | |
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) | |
} | |
req.Header.Add("Content-Type", "application/json") | |
res, err := v.httpClient.Do(req) | |
if err != nil { | |
return nil, errors.Wrap(err, "send POST request") | |
} | |
defer res.Body.Close() | |
bodyBytes, err := io.ReadAll(res.Body) | |
if err != nil { | |
return nil, errors.Wrap(err, "read response body") | |
} | |
if err := checkResponse(res, bodyBytes); err != nil { | |
return nil, err | |
} | |
vector, err := v.decodeVector(bodyBytes) | |
if err != nil { | |
return nil, errors.Wrap(err, "cannot decode vector") | |
} | |
return &ent.VectorizationResult{ | |
Text: input, | |
Dimensions: len(vector), | |
Vector: vector, | |
}, nil | |
} | |
func checkResponse(res *http.Response, bodyBytes []byte) error { | |
if res.StatusCode < 400 { | |
return nil | |
} | |
var resBody huggingFaceApiError | |
if err := json.Unmarshal(bodyBytes, &resBody); err != nil { | |
return fmt.Errorf("unmarshal error response body: %v", string(bodyBytes)) | |
} | |
message := fmt.Sprintf("failed with status: %d", res.StatusCode) | |
if resBody.Error != "" { | |
message = fmt.Sprintf("%s error: %v", message, resBody.Error) | |
if resBody.EstimatedTime != nil { | |
message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime) | |
} | |
if len(resBody.Warnings) > 0 { | |
message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings) | |
} | |
} | |
if res.StatusCode == http.StatusInternalServerError { | |
message = fmt.Sprintf("connection to HuggingFace %v", message) | |
} | |
return errors.New(message) | |
} | |
func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) { | |
var emb embedding | |
if err := json.Unmarshal(bodyBytes, &emb); err != nil { | |
var embObject embeddingObject | |
if err := json.Unmarshal(bodyBytes, &embObject); err != nil { | |
var embBert embeddingBert | |
if err := json.Unmarshal(bodyBytes, &embBert); err != nil { | |
return nil, errors.Wrap(err, "unmarshal response body") | |
} | |
if len(embBert) == 1 && len(embBert[0]) == 1 { | |
return v.bertEmbeddingsDecoder.calculateVector(embBert[0][0]) | |
} | |
return nil, errors.New("unprocessable response body") | |
} | |
if len(embObject.Embeddings) == 1 { | |
return embObject.Embeddings[0], nil | |
} | |
return nil, errors.New("unprocessable response body") | |
} | |
if len(emb) == 1 { | |
return emb[0], nil | |
} | |
return nil, errors.New("unprocessable response body") | |
} | |
func (v *vectorizer) getApiKey(ctx context.Context) string { | |
if len(v.apiKey) > 0 { | |
return v.apiKey | |
} | |
key := "X-Huggingface-Api-Key" | |
apiKey := ctx.Value(key) | |
// try getting header from GRPC if not successful | |
if apiKey == nil { | |
apiKey = modulecomponents.GetValueFromGRPC(ctx, key) | |
} | |
if apiKeyHeader, ok := apiKey.([]string); ok && | |
len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 { | |
return apiKeyHeader[0] | |
} | |
return "" | |
} | |
func (v *vectorizer) getOptions(config ent.VectorizationConfig) options { | |
return options{ | |
WaitForModel: config.WaitForModel, | |
UseGPU: config.UseGPU, | |
UseCache: config.UseCache, | |
} | |
} | |
func (v *vectorizer) getURL(config ent.VectorizationConfig) string { | |
if config.EndpointURL != "" { | |
return config.EndpointURL | |
} | |
return fmt.Sprintf("%s/%s/%s", DefaultOrigin, DefaultPath, config.Model) | |
} | |