Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package client | |
import ( | |
"context" | |
"fmt" | |
"strings" | |
"time" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
pb "github.com/weaviate/contextionary/contextionary" | |
"github.com/weaviate/weaviate/entities/models" | |
txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models" | |
"github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer" | |
"github.com/weaviate/weaviate/usecases/traverser" | |
"google.golang.org/grpc" | |
"google.golang.org/grpc/codes" | |
"google.golang.org/grpc/credentials/insecure" | |
"google.golang.org/grpc/status" | |
) | |
// ModelUncontactable is the warning message logged by this package when a
// call to the remote contextionary fails with a "connection refused" error.
const ModelUncontactable = "module uncontactable"
// Client talks to a remote contextionary service over gRPC. Construct it
// with NewClient.
type Client struct {
	// grpcClient is the generated contextionary stub that every method of
	// Client forwards its request to.
	grpcClient pb.ContextionaryClient
	// logger receives the warnings emitted when the remote cannot be reached
	// and the progress messages of the startup check.
	logger logrus.FieldLogger
}
// NewClient from gRPC discovery url to connect to a remote contextionary service | |
func NewClient(uri string, logger logrus.FieldLogger) (*Client, error) { | |
conn, err := grpc.Dial(uri, | |
grpc.WithTransportCredentials(insecure.NewCredentials()), | |
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1024*1024*48))) | |
if err != nil { | |
return nil, fmt.Errorf("couldn't connect to remote contextionary gRPC server: %s", err) | |
} | |
client := pb.NewContextionaryClient(conn) | |
return &Client{ | |
grpcClient: client, | |
logger: logger, | |
}, nil | |
} | |
// IsStopWord returns true if the given word is a stopword, errors on connection errors | |
func (c *Client) IsStopWord(ctx context.Context, word string) (bool, error) { | |
res, err := c.grpcClient.IsWordStopword(ctx, &pb.Word{Word: word}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return false, err | |
} | |
return res.Stopword, nil | |
} | |
// IsWordPresent returns true if the given word is a stopword, errors on connection errors | |
func (c *Client) IsWordPresent(ctx context.Context, word string) (bool, error) { | |
res, err := c.grpcClient.IsWordPresent(ctx, &pb.Word{Word: word}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return false, err | |
} | |
return res.Present, nil | |
} | |
// SafeGetSimilarWordsWithCertainty will always return a list words - unless there is a network error | |
func (c *Client) SafeGetSimilarWordsWithCertainty(ctx context.Context, word string, certainty float32) ([]string, error) { | |
res, err := c.grpcClient.SafeGetSimilarWordsWithCertainty(ctx, &pb.SimilarWordsParams{Word: word, Certainty: certainty}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return nil, err | |
} | |
output := make([]string, len(res.Words)) | |
for i, word := range res.Words { | |
output[i] = word.Word | |
} | |
return output, nil | |
} | |
// SchemaSearch for related classes and properties | |
// TODO: is this still used? | |
func (c *Client) SchemaSearch(ctx context.Context, params traverser.SearchParams) (traverser.SearchResults, error) { | |
pbParams := &pb.SchemaSearchParams{ | |
Certainty: params.Certainty, | |
Name: params.Name, | |
SearchType: searchTypeToProto(params.SearchType), | |
} | |
res, err := c.grpcClient.SchemaSearch(ctx, pbParams) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return traverser.SearchResults{}, err | |
} | |
return schemaSearchResultsFromProto(res), nil | |
} | |
func searchTypeToProto(input traverser.SearchType) pb.SearchType { | |
switch input { | |
case traverser.SearchTypeClass: | |
return pb.SearchType_CLASS | |
case traverser.SearchTypeProperty: | |
return pb.SearchType_PROPERTY | |
default: | |
panic(fmt.Sprintf("unknown search type %v", input)) | |
} | |
} | |
func searchTypeFromProto(input pb.SearchType) traverser.SearchType { | |
switch input { | |
case pb.SearchType_CLASS: | |
return traverser.SearchTypeClass | |
case pb.SearchType_PROPERTY: | |
return traverser.SearchTypeProperty | |
default: | |
panic(fmt.Sprintf("unknown search type %v", input)) | |
} | |
} | |
func schemaSearchResultsFromProto(res *pb.SchemaSearchResults) traverser.SearchResults { | |
return traverser.SearchResults{ | |
Type: searchTypeFromProto(res.Type), | |
Results: searchResultsFromProto(res.Results), | |
} | |
} | |
func searchResultsFromProto(input []*pb.SchemaSearchResult) []traverser.SearchResult { | |
output := make([]traverser.SearchResult, len(input)) | |
for i, res := range input { | |
output[i] = traverser.SearchResult{ | |
Certainty: res.Certainty, | |
Name: res.Name, | |
} | |
} | |
return output | |
} | |
func (c *Client) VectorForWord(ctx context.Context, word string) ([]float32, error) { | |
res, err := c.grpcClient.VectorForWord(ctx, &pb.Word{Word: word}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return nil, fmt.Errorf("could not get vector from remote: %v", err) | |
} | |
v, _, _ := vectorFromProto(res) | |
return v, nil | |
} | |
func logConnectionRefused(logger logrus.FieldLogger, err error) { | |
if strings.Contains(fmt.Sprintf("%v", err), "connect: connection refused") { | |
logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) | |
} else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") { | |
logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) | |
} | |
} | |
func (c *Client) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) { | |
out := make([][]float32, len(words)) | |
wordParams := make([]*pb.Word, len(words)) | |
for i, word := range words { | |
wordParams[i] = &pb.Word{Word: word} | |
} | |
res, err := c.grpcClient.MultiVectorForWord(ctx, &pb.WordList{Words: wordParams}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return nil, err | |
} | |
for i, elem := range res.Vectors { | |
if len(elem.Entries) == 0 { | |
// indicates word not found | |
continue | |
} | |
out[i], _, _ = vectorFromProto(elem) | |
} | |
return out, nil | |
} | |
func (c *Client) MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error) { | |
out := make([]*txt2vecmodels.NearestNeighbors, len(vectors)) | |
searchParams := make([]*pb.VectorNNParams, len(vectors)) | |
for i, vector := range vectors { | |
searchParams[i] = &pb.VectorNNParams{ | |
Vector: vectorToProto(vector), | |
K: int32(k), | |
N: int32(n), | |
} | |
} | |
res, err := c.grpcClient.MultiNearestWordsByVector(ctx, &pb.VectorNNParamsList{Params: searchParams}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return nil, err | |
} | |
for i, elem := range res.Words { | |
out[i] = &txt2vecmodels.NearestNeighbors{ | |
Neighbors: c.extractNeighbors(elem), | |
} | |
} | |
return out, nil | |
} | |
func (c *Client) extractNeighbors(elem *pb.NearestWords) []*txt2vecmodels.NearestNeighbor { | |
out := make([]*txt2vecmodels.NearestNeighbor, len(elem.Words)) | |
for i := range out { | |
vec, _, _ := vectorFromProto(elem.Vectors.Vectors[i]) | |
out[i] = &txt2vecmodels.NearestNeighbor{ | |
Concept: elem.Words[i], | |
Distance: elem.Distances[i], | |
Vector: vec, | |
} | |
} | |
return out | |
} | |
func vectorFromProto(in *pb.Vector) ([]float32, []txt2vecmodels.InterpretationSource, error) { | |
output := make([]float32, len(in.Entries)) | |
for i, entry := range in.Entries { | |
output[i] = entry.Entry | |
} | |
source := make([]txt2vecmodels.InterpretationSource, len(in.Source)) | |
for i, s := range in.Source { | |
source[i].Concept = s.Concept | |
source[i].Weight = float64(s.Weight) | |
source[i].Occurrence = s.Occurrence | |
} | |
return output, source, nil | |
} | |
func (c *Client) VectorForCorpi(ctx context.Context, corpi []string, overridesMap map[string]string) ([]float32, []txt2vecmodels.InterpretationSource, error) { | |
overrides := overridesFromMap(overridesMap) | |
res, err := c.grpcClient.VectorForCorpi(ctx, &pb.Corpi{Corpi: corpi, Overrides: overrides}) | |
if err != nil { | |
if strings.Contains(err.Error(), "connect: connection refused") { | |
c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) | |
} else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") { | |
c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) | |
} | |
st, ok := status.FromError(err) | |
if !ok || st.Code() != codes.InvalidArgument { | |
return nil, nil, fmt.Errorf("could not get vector from remote: %v", err) | |
} | |
return nil, nil, vectorizer.NewErrNoUsableWordsf(st.Message()) | |
} | |
return vectorFromProto(res) | |
} | |
func (c *Client) VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error) { | |
vec, _, err := c.VectorForCorpi(ctx, corpi, overrides) | |
return vec, err | |
} | |
func (c *Client) NearestWordsByVector(ctx context.Context, vector []float32, n int, k int) ([]string, []float32, error) { | |
res, err := c.grpcClient.NearestWordsByVector(ctx, &pb.VectorNNParams{ | |
K: int32(k), | |
N: int32(n), | |
Vector: vectorToProto(vector), | |
}) | |
if err != nil { | |
logConnectionRefused(c.logger, err) | |
return nil, nil, fmt.Errorf("could not get nearest words by vector: %v", err) | |
} | |
return res.Words, res.Distances, nil | |
} | |
func (c *Client) AddExtension(ctx context.Context, extension *models.C11yExtension) error { | |
_, err := c.grpcClient.AddExtension(ctx, &pb.ExtensionInput{ | |
Concept: extension.Concept, | |
Definition: strings.ToLower(extension.Definition), | |
Weight: extension.Weight, | |
}) | |
return err | |
} | |
func vectorToProto(in []float32) *pb.Vector { | |
output := make([]*pb.VectorEntry, len(in)) | |
for i, entry := range in { | |
output[i] = &pb.VectorEntry{ | |
Entry: entry, | |
} | |
} | |
return &pb.Vector{Entries: output} | |
} | |
func (c *Client) WaitForStartupAndValidateVersion(startupCtx context.Context, | |
requiredMinimumVersion string, interval time.Duration, | |
) error { | |
for { | |
if err := startupCtx.Err(); err != nil { | |
return errors.Wrap(err, "wait for contextionary remote inference service") | |
} | |
time.Sleep(interval) | |
ctx, cancel := context.WithTimeout(startupCtx, 2*time.Second) | |
defer cancel() | |
v, err := c.version(ctx) | |
if err != nil { | |
c.logger.WithField("action", "startup_check_contextionary").WithError(err). | |
Warnf("could not connect to contextionary at startup, trying again in 1 sec") | |
continue | |
} | |
ok, err := extractVersionAndCompare(v, requiredMinimumVersion) | |
if err != nil { | |
c.logger.WithField("action", "startup_check_contextionary"). | |
WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion). | |
WithField("contextionaryVersion", v). | |
WithError(err). | |
Warnf("cannot determine if contextionary version is compatible. " + | |
"This is fine in development, but probelematic if you see this production") | |
return nil | |
} | |
if ok { | |
c.logger.WithField("action", "startup_check_contextionary"). | |
WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion). | |
WithField("contextionaryVersion", v). | |
Infof("found a valid contextionary version") | |
return nil | |
} else { | |
return errors.Errorf("insuffcient contextionary version: need at least %s, got %s", | |
requiredMinimumVersion, v) | |
} | |
} | |
} | |
func overridesFromMap(in map[string]string) []*pb.Override { | |
if in == nil { | |
return nil | |
} | |
out := make([]*pb.Override, len(in)) | |
i := 0 | |
for key, value := range in { | |
out[i] = &pb.Override{ | |
Word: key, | |
Expression: value, | |
} | |
i++ | |
} | |
return out | |
} | |