KevinStephenson
Adding in weaviate code
b110593
raw
history blame
11.7 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package client
import (
"context"
"fmt"
"strings"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
pb "github.com/weaviate/contextionary/contextionary"
"github.com/weaviate/weaviate/entities/models"
txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models"
"github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer"
"github.com/weaviate/weaviate/usecases/traverser"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// ModelUncontactable is the warning message logged (tagged
// module=contextionary) when a gRPC call to the contextionary fails because
// the connection was refused.
const ModelUncontactable = "module uncontactable"
// Client establishes a gRPC connection to a remote contextionary service and
// exposes its RPCs as Go methods.
type Client struct {
	// grpcClient is the generated stub every remote call goes through.
	grpcClient pb.ContextionaryClient
	// logger receives connection-refused warnings and startup diagnostics.
	logger logrus.FieldLogger
}
// NewClient dials the contextionary gRPC service at uri and returns a Client
// wrapping it.
//
// The connection uses insecure (plaintext) transport credentials and raises
// the per-call receive limit to 48 MiB. A dial failure is returned wrapped with
// %w, so callers can still inspect the underlying error with errors.Is/As.
//
// NOTE(review): the *grpc.ClientConn is not retained, so it can never be
// closed; the connection lives for the rest of the process.
func NewClient(uri string, logger logrus.FieldLogger) (*Client, error) {
	// maxRecvMsgSize lifts gRPC's default receive limit, presumably so
	// responses carrying many vectors are not rejected.
	const maxRecvMsgSize = 48 * 1024 * 1024

	conn, err := grpc.Dial(uri,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxRecvMsgSize)))
	if err != nil {
		return nil, fmt.Errorf("couldn't connect to remote contextionary gRPC server: %w", err)
	}

	return &Client{
		grpcClient: pb.NewContextionaryClient(conn),
		logger:     logger,
	}, nil
}
// IsStopWord asks the remote contextionary whether word is a stopword.
// When the RPC fails, the error is returned unchanged with a false result, and
// refused connections are additionally logged as a warning.
func (c *Client) IsStopWord(ctx context.Context, word string) (bool, error) {
	req := &pb.Word{Word: word}
	resp, rpcErr := c.grpcClient.IsWordStopword(ctx, req)
	if rpcErr != nil {
		logConnectionRefused(c.logger, rpcErr)
		return false, rpcErr
	}
	isStop := resp.Stopword
	return isStop, nil
}
// IsWordPresent asks the remote contextionary whether word is present, i.e.
// it returns the Present flag of the IsWordPresent RPC response (this is not a
// stopword check). When the RPC fails, the error is returned unchanged with a
// false result, and refused connections are additionally logged as a warning.
func (c *Client) IsWordPresent(ctx context.Context, word string) (bool, error) {
	req := &pb.Word{Word: word}
	resp, rpcErr := c.grpcClient.IsWordPresent(ctx, req)
	if rpcErr != nil {
		logConnectionRefused(c.logger, rpcErr)
		return false, rpcErr
	}
	present := resp.Present
	return present, nil
}
// SafeGetSimilarWordsWithCertainty returns the words the remote contextionary
// reports as similar to word for the given certainty. Only an RPC error makes
// it fail; otherwise it returns a (possibly empty, but non-nil) list in
// response order.
func (c *Client) SafeGetSimilarWordsWithCertainty(ctx context.Context, word string, certainty float32) ([]string, error) {
	params := &pb.SimilarWordsParams{Word: word, Certainty: certainty}
	res, err := c.grpcClient.SafeGetSimilarWordsWithCertainty(ctx, params)
	if err != nil {
		logConnectionRefused(c.logger, err)
		return nil, err
	}

	similar := make([]string, 0, len(res.Words))
	for _, w := range res.Words {
		similar = append(similar, w.Word)
	}
	return similar, nil
}
// SchemaSearch asks the remote contextionary for related classes or
// properties, forwarding params.Name, params.Certainty and params.SearchType.
// It panics (via searchTypeToProto) if params.SearchType is neither a class nor
// a property search. RPC errors are returned with zero-valued results.
// TODO: is this still used?
func (c *Client) SchemaSearch(ctx context.Context, params traverser.SearchParams) (traverser.SearchResults, error) {
	res, err := c.grpcClient.SchemaSearch(ctx, &pb.SchemaSearchParams{
		Name:       params.Name,
		Certainty:  params.Certainty,
		SearchType: searchTypeToProto(params.SearchType),
	})
	if err != nil {
		logConnectionRefused(c.logger, err)
		return traverser.SearchResults{}, err
	}
	return schemaSearchResultsFromProto(res), nil
}
// searchTypeToProto maps a traverser.SearchType onto the protobuf enum.
// Only SearchTypeClass and SearchTypeProperty are supported; any other value
// is a programming error and panics.
func searchTypeToProto(input traverser.SearchType) pb.SearchType {
	if input == traverser.SearchTypeClass {
		return pb.SearchType_CLASS
	}
	if input == traverser.SearchTypeProperty {
		return pb.SearchType_PROPERTY
	}
	panic(fmt.Sprintf("unknown search type %v", input))
}
// searchTypeFromProto is the inverse of searchTypeToProto: it maps the
// protobuf enum back onto traverser.SearchType and panics on any value other
// than CLASS or PROPERTY.
func searchTypeFromProto(input pb.SearchType) traverser.SearchType {
	if input == pb.SearchType_CLASS {
		return traverser.SearchTypeClass
	}
	if input == pb.SearchType_PROPERTY {
		return traverser.SearchTypeProperty
	}
	panic(fmt.Sprintf("unknown search type %v", input))
}
// schemaSearchResultsFromProto converts a protobuf schema search response into
// traverser.SearchResults: the search type is mapped back and every hit is
// copied over in order.
func schemaSearchResultsFromProto(res *pb.SchemaSearchResults) traverser.SearchResults {
	var out traverser.SearchResults
	out.Type = searchTypeFromProto(res.Type)
	out.Results = searchResultsFromProto(res.Results)
	return out
}
// searchResultsFromProto copies the name and certainty of each protobuf
// schema search hit into a traverser.SearchResult, preserving order. The
// result is never nil, even for empty input.
func searchResultsFromProto(input []*pb.SchemaSearchResult) []traverser.SearchResult {
	results := make([]traverser.SearchResult, 0, len(input))
	for _, hit := range input {
		results = append(results, traverser.SearchResult{
			Name:      hit.Name,
			Certainty: hit.Certainty,
		})
	}
	return results
}
// VectorForWord fetches the vector of a single word from the remote
// contextionary. On failure, refused connections are logged and the error is
// returned prefixed with "could not get vector from remote: ".
func (c *Client) VectorForWord(ctx context.Context, word string) ([]float32, error) {
	req := &pb.Word{Word: word}
	resp, rpcErr := c.grpcClient.VectorForWord(ctx, req)
	if rpcErr != nil {
		logConnectionRefused(c.logger, rpcErr)
		return nil, fmt.Errorf("could not get vector from remote: %v", rpcErr)
	}
	// Only the raw entries are needed; interpretation sources are dropped and
	// vectorFromProto's error result is ignored.
	vector, _, _ := vectorFromProto(resp)
	return vector, nil
}
// logConnectionRefused logs a ModelUncontactable warning, tagged
// module=contextionary, when err's message contains either the Unix-style
// "connect: connection refused" text or the Windows-style "connectex: ...
// actively refused it." text. Any other error, and a nil err, is ignored. It
// only logs; returning err is left to the caller.
func logConnectionRefused(logger logrus.FieldLogger, err error) {
	if err == nil {
		return
	}

	msg := err.Error()
	if strings.Contains(msg, "connect: connection refused") ||
		strings.Contains(msg, "connectex: No connection could be made because the target machine actively refused it.") {
		// Warn rather than Warnf: the message is a fixed string, not a format.
		logger.WithError(err).WithField("module", "contextionary").Warn(ModelUncontactable)
	}
}
// MultiVectorForWord fetches the vectors of several words in a single RPC.
// The result has len(words) slots; res.Vectors[i] is decoded into slot i, and
// a slot stays nil when the remote returned an empty vector (word not found).
//
// NOTE(review): this panics with an index-out-of-range if the remote returns
// more vectors than words were sent.
func (c *Client) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) {
	req := &pb.WordList{Words: make([]*pb.Word, 0, len(words))}
	for _, w := range words {
		req.Words = append(req.Words, &pb.Word{Word: w})
	}

	res, err := c.grpcClient.MultiVectorForWord(ctx, req)
	if err != nil {
		logConnectionRefused(c.logger, err)
		return nil, err
	}

	vectors := make([][]float32, len(words))
	for i, v := range res.Vectors {
		// An empty entry list is how the remote signals "word not found".
		if len(v.Entries) > 0 {
			vectors[i], _, _ = vectorFromProto(v)
		}
	}
	return vectors, nil
}
// MultiNearestWordsByVector runs one nearest-words query per input vector in a
// single RPC, sending the same k and n with every query. The result has
// len(vectors) slots; slot i holds the neighbors decoded from the i-th entry
// of the response, and slots without a response entry stay nil.
func (c *Client) MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error) {
	params := make([]*pb.VectorNNParams, 0, len(vectors))
	for _, vec := range vectors {
		params = append(params, &pb.VectorNNParams{
			K:      int32(k),
			N:      int32(n),
			Vector: vectorToProto(vec),
		})
	}

	res, err := c.grpcClient.MultiNearestWordsByVector(ctx, &pb.VectorNNParamsList{Params: params})
	if err != nil {
		logConnectionRefused(c.logger, err)
		return nil, err
	}

	neighbors := make([]*txt2vecmodels.NearestNeighbors, len(vectors))
	for i, nearest := range res.Words {
		neighbors[i] = &txt2vecmodels.NearestNeighbors{Neighbors: c.extractNeighbors(nearest)}
	}
	return neighbors, nil
}
// extractNeighbors combines the parallel Words, Distances and Vectors.Vectors
// slices of a protobuf NearestWords message into NearestNeighbor values. The
// length of elem.Words drives the loop; the other two slices are indexed at the
// same position and must be at least as long.
func (c *Client) extractNeighbors(elem *pb.NearestWords) []*txt2vecmodels.NearestNeighbor {
	neighbors := make([]*txt2vecmodels.NearestNeighbor, 0, len(elem.Words))
	for i, concept := range elem.Words {
		vector, _, _ := vectorFromProto(elem.Vectors.Vectors[i])
		neighbors = append(neighbors, &txt2vecmodels.NearestNeighbor{
			Concept:  concept,
			Distance: elem.Distances[i],
			Vector:   vector,
		})
	}
	return neighbors
}
// vectorFromProto decodes a protobuf Vector into its float32 entries and its
// interpretation sources. For each source it copies Concept and Occurrence and
// widens Weight to float64. Both slices are non-nil, and the error result is
// always nil.
func vectorFromProto(in *pb.Vector) ([]float32, []txt2vecmodels.InterpretationSource, error) {
	entries := make([]float32, 0, len(in.Entries))
	for _, e := range in.Entries {
		entries = append(entries, e.Entry)
	}

	sources := make([]txt2vecmodels.InterpretationSource, 0, len(in.Source))
	for _, s := range in.Source {
		sources = append(sources, txt2vecmodels.InterpretationSource{
			Concept:    s.Concept,
			Weight:     float64(s.Weight),
			Occurrence: s.Occurrence,
		})
	}

	return entries, sources, nil
}
// VectorForCorpi asks the remote contextionary to vectorize corpi, sending the
// word->expression pairs of overridesMap as overrides. It returns the vector
// and the interpretation sources decoded from the response.
//
// Errors:
//   - a gRPC InvalidArgument status becomes the vectorizer's "no usable words"
//     error, carrying the status message verbatim;
//   - any other failure is returned as "could not get vector from remote: ...".
//
// Refused connections are additionally logged as a warning.
func (c *Client) VectorForCorpi(ctx context.Context, corpi []string, overridesMap map[string]string) ([]float32, []txt2vecmodels.InterpretationSource, error) {
	req := &pb.Corpi{Corpi: corpi, Overrides: overridesFromMap(overridesMap)}
	res, err := c.grpcClient.VectorForCorpi(ctx, req)
	if err != nil {
		// Same connection-refused detection as every other RPC wrapper.
		logConnectionRefused(c.logger, err)

		st, ok := status.FromError(err)
		if !ok || st.Code() != codes.InvalidArgument {
			return nil, nil, fmt.Errorf("could not get vector from remote: %v", err)
		}
		// Pass the remote message as an argument rather than as the format
		// string, so any '%' it contains is reproduced literally.
		return nil, nil, vectorizer.NewErrNoUsableWordsf("%s", st.Message())
	}
	return vectorFromProto(res)
}
// VectorOnlyForCorpi behaves like VectorForCorpi but discards the
// interpretation sources, returning only the vector and the error.
func (c *Client) VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error) {
	var (
		vector []float32
		err    error
	)
	vector, _, err = c.VectorForCorpi(ctx, corpi, overrides)
	return vector, err
}
// NearestWordsByVector asks the remote contextionary for the words nearest to
// vector, forwarding n and k as the request's N and K. It returns the words
// and their distances exactly as the response lists them.
//
// NOTE(review): the parameter order here is (n, k), while
// MultiNearestWordsByVector takes (k, n); callers should double-check.
func (c *Client) NearestWordsByVector(ctx context.Context, vector []float32, n int, k int) ([]string, []float32, error) {
	params := &pb.VectorNNParams{
		Vector: vectorToProto(vector),
		N:      int32(n),
		K:      int32(k),
	}
	res, rpcErr := c.grpcClient.NearestWordsByVector(ctx, params)
	if rpcErr != nil {
		logConnectionRefused(c.logger, rpcErr)
		return nil, nil, fmt.Errorf("could not get nearest words by vector: %v", rpcErr)
	}
	return res.Words, res.Distances, nil
}
// AddExtension sends a custom concept extension to the remote contextionary.
// Concept and Weight are forwarded unchanged, and Definition is lower-cased
// first. Like the other RPC wrappers, it logs refused connections as a
// warning. Any RPC error is returned as-is.
func (c *Client) AddExtension(ctx context.Context, extension *models.C11yExtension) error {
	_, err := c.grpcClient.AddExtension(ctx, &pb.ExtensionInput{
		Concept:    extension.Concept,
		Definition: strings.ToLower(extension.Definition),
		Weight:     extension.Weight,
	})
	if err != nil {
		logConnectionRefused(c.logger, err)
	}
	return err
}
// vectorToProto wraps each float32 of in as a protobuf VectorEntry, preserving
// order. Only Entries is populated on the returned Vector.
func vectorToProto(in []float32) *pb.Vector {
	vec := &pb.Vector{Entries: make([]*pb.VectorEntry, 0, len(in))}
	for _, f := range in {
		vec.Entries = append(vec.Entries, &pb.VectorEntry{Entry: f})
	}
	return vec
}
// WaitForStartupAndValidateVersion polls the remote contextionary until it
// answers a version request, then compares that version with
// requiredMinimumVersion.
//
// Each attempt first waits interval and then issues the version call with a
// 2-second timeout derived from startupCtx. The function returns:
//   - the wrapped context error once startupCtx is done (checked both before
//     and during the wait);
//   - nil, after logging a warning, if the version cannot be compared;
//   - nil if the version is sufficient;
//   - an error if the version is older than required.
func (c *Client) WaitForStartupAndValidateVersion(startupCtx context.Context,
	requiredMinimumVersion string, interval time.Duration,
) error {
	const (
		attemptTimeout = 2 * time.Second
		waitMsg        = "wait for contextionary remote inference service"
	)

	for {
		if err := startupCtx.Err(); err != nil {
			return errors.Wrap(err, waitMsg)
		}

		// Wait for interval, but return promptly if startupCtx is cancelled
		// instead of sleeping through the cancellation.
		timer := time.NewTimer(interval)
		select {
		case <-startupCtx.Done():
			timer.Stop()
			return errors.Wrap(startupCtx.Err(), waitMsg)
		case <-timer.C:
		}

		// Release each attempt's context as soon as the call returns; a defer
		// inside this loop would pile up until the function exits.
		ctx, cancel := context.WithTimeout(startupCtx, attemptTimeout)
		v, err := c.version(ctx)
		cancel()
		if err != nil {
			c.logger.WithField("action", "startup_check_contextionary").WithError(err).
				Warnf("could not connect to contextionary at startup, trying again in %s", interval)
			continue
		}

		ok, err := extractVersionAndCompare(v, requiredMinimumVersion)
		if err != nil {
			// Deliberately tolerated: an unparsable version is accepted with a warning.
			c.logger.WithField("action", "startup_check_contextionary").
				WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion).
				WithField("contextionaryVersion", v).
				WithError(err).
				Warnf("cannot determine if contextionary version is compatible. " +
					"This is fine in development, but probelematic if you see this production")
			return nil
		}

		if !ok {
			return errors.Errorf("insuffcient contextionary version: need at least %s, got %s",
				requiredMinimumVersion, v)
		}

		c.logger.WithField("action", "startup_check_contextionary").
			WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion).
			WithField("contextionaryVersion", v).
			Infof("found a valid contextionary version")
		return nil
	}
}
// overridesFromMap converts a word->expression map into protobuf Override
// messages. A nil map yields nil; an empty map yields an empty, non-nil slice.
// Output order follows Go's unspecified map iteration order.
func overridesFromMap(in map[string]string) []*pb.Override {
	if in == nil {
		return nil
	}
	overrides := make([]*pb.Override, 0, len(in))
	for word, expr := range in {
		overrides = append(overrides, &pb.Override{Word: word, Expression: expr})
	}
	return overrides
}