KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.02 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package modules
import (
"context"
"fmt"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/moduletools"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/vectorindex/flat"
"github.com/weaviate/weaviate/entities/vectorindex/hnsw"
"github.com/weaviate/weaviate/usecases/config"
)
// Message templates shared by the vectorizer helpers in this file.
const (
	// errorVectorizerCapability is the error format used when a module is
	// registered but offers neither the Vectorizer nor the ReferenceVectorizer
	// capability. It takes the module name (%q).
	errorVectorizerCapability = "module %q exists, but does not provide " +
		"the Vectorizer or ReferenceVectorizer capability"

	// errorVectorIndexType is the error format used when a class carries an
	// unsupported vector index config. It takes the offending config (%T).
	// NOTE(review): the text only mentions HNSW, although UpdateVector also
	// accepts flat.UserConfig.
	errorVectorIndexType = "vector index config (%T) is not of type HNSW, but " +
		"objects manager is restricted to HNSW"

	// warningVectorIgnored is the shared tail of both skip warnings below; it
	// tells the user how to resolve the conflicting configuration.
	warningVectorIgnored = "This vector will be ignored. " +
		"If you meant to index the vector, make sure to set " +
		"vectorIndexConfig.skip to 'false'. " +
		"If the previous setting is correct, make sure you set vectorizer to " +
		"'none' in the schema and provide a null-vector (i.e. no vector) at " +
		"import time."

	// warningSkipVectorGenerated is logged when indexing is skipped although a
	// vectorizer module is configured. It takes the vectorizer name (%q).
	warningSkipVectorGenerated = "this class is configured to skip vector " +
		"indexing, but a vector was generated by the %q vectorizer. " +
		warningVectorIgnored

	// warningSkipVectorProvided is logged when indexing is skipped although
	// the object arrived with a vector already attached.
	warningSkipVectorProvided = "this class is configured to skip vector " +
		"indexing, but a vector was explicitly provided. " +
		warningVectorIgnored
)
// ValidateVectorizer checks that a module called moduleName is registered with
// the provider and that it can act as a vectorizer, i.e. that it implements
// modulecapabilities.Vectorizer or modulecapabilities.ReferenceVectorizer.
//
// It returns an error when no such module exists or when the module offers
// neither capability, and nil otherwise.
func (p *Provider) ValidateVectorizer(moduleName string) error {
	candidate := p.GetByName(moduleName)
	if candidate == nil {
		return errors.Errorf("no module with name %q present", moduleName)
	}

	switch candidate.(type) {
	case modulecapabilities.Vectorizer, modulecapabilities.ReferenceVectorizer:
		// Either capability is enough to qualify as a vectorizer.
		return nil
	default:
		return errors.Errorf(errorVectorizerCapability, moduleName)
	}
}
// UsingRef2Vec reports whether the class called className has at least one
// configured module that provides the ReferenceVectorizer capability.
//
// It returns false when the class cannot be looked up, when the class has no
// module config, or when the module config is not a map keyed by module name.
func (p *Provider) UsingRef2Vec(className string) bool {
	class, err := p.getClass(className)
	if err != nil {
		return false
	}

	// Checked assertion: a nil or unexpectedly typed ModuleConfig is treated
	// as "no ref2vec module configured" instead of triggering a panic, the
	// same way UpdateVector inspects the module config.
	modConfig, ok := class.ModuleConfig.(map[string]interface{})
	if !ok {
		return false
	}

	for modName := range modConfig {
		// Asserting on a nil module (unknown name) simply yields ok == false.
		if _, ok := p.GetByName(modName).(modulecapabilities.ReferenceVectorizer); ok {
			return true
		}
	}
	return false
}
// UpdateVector runs the vectorizer module configured for class against object.
//
// The steps are:
//  1. The class must use an HNSW or flat vector index config; anything else is
//     rejected with an error.
//  2. If the class vectorizer is config.VectorizerModuleNone nothing is
//     vectorized. A warning is logged when the HNSW config skips indexing but
//     the object already carries a vector.
//  3. Otherwise a warning is logged when the HNSW config skips indexing, and
//     the first entry of the class module config that passes
//     ValidateVectorizer is selected.
//  4. A module with the Vectorizer capability is only invoked when
//     object.Vector is nil; a module with only the ReferenceVectorizer
//     capability is always invoked and receives findObjectFn.
//
// Errors from the module are wrapped and returned. An error is also returned
// when the module config is not a map or holds no usable vectorizer.
func (p *Provider) UpdateVector(ctx context.Context, object *models.Object, class *models.Class,
	objectDiff *moduletools.ObjectDiff, findObjectFn modulecapabilities.FindObjectFn,
	logger logrus.FieldLogger,
) error {
	// Only the HNSW config is consulted for the skip flag, so it stays false
	// for flat configs.
	skipIndexing := false
	switch indexCfg := class.VectorIndexConfig.(type) {
	case hnsw.UserConfig:
		skipIndexing = indexCfg.Skip
	case flat.UserConfig:
		// Supported, nothing to read from it here.
	default:
		return fmt.Errorf(errorVectorIndexType, class.VectorIndexConfig)
	}

	if class.Vectorizer == config.VectorizerModuleNone {
		// No module will run; only point out a vector that will go unused.
		if skipIndexing && len(object.Vector) > 0 {
			logger.WithField("className", object.Class).
				Warningf(warningSkipVectorProvided)
		}
		return nil
	}

	if skipIndexing {
		logger.WithField("className", object.Class).
			WithField("vectorizer", class.Vectorizer).
			Warningf(warningSkipVectorGenerated, class.Vectorizer)
	}

	modConfig, isMap := class.ModuleConfig.(map[string]interface{})
	if !isMap {
		return fmt.Errorf("class %v not present", object.Class)
	}

	// Pick the first configured module that qualifies as a vectorizer.
	var vectorizerMod modulecapabilities.Module
	for name := range modConfig {
		if p.ValidateVectorizer(name) != nil {
			continue
		}
		vectorizerMod = p.GetByName(name)
		break
	}
	if vectorizerMod == nil {
		return fmt.Errorf(
			"no vectorizer found for class %q", object.Class)
	}

	cfg := NewClassBasedModuleConfig(class, vectorizerMod.Name(), "")

	objVectorizer, isObjVectorizer := vectorizerMod.(modulecapabilities.Vectorizer)
	if !isObjVectorizer {
		// ValidateVectorizer accepted the module, so without the Vectorizer
		// capability it has to be a ReferenceVectorizer.
		refVectorizer := vectorizerMod.(modulecapabilities.ReferenceVectorizer)
		if err := refVectorizer.VectorizeObject(
			ctx, object, cfg, findObjectFn); err != nil {
			return fmt.Errorf("update reference vector: %w", err)
		}
		return nil
	}

	// A vector that is already set on the object is left untouched.
	if object.Vector != nil {
		return nil
	}
	if err := objVectorizer.VectorizeObject(ctx, object, objectDiff, cfg); err != nil {
		return fmt.Errorf("update vector: %w", err)
	}
	return nil
}
// VectorizerName returns the vectorizer configured on the class called
// className. Any lookup error from getClassVectorizer is passed through
// together with an empty name.
func (p *Provider) VectorizerName(className string) (string, error) {
	vectorizer, _, lookupErr := p.getClassVectorizer(className)
	if lookupErr != nil {
		return "", lookupErr
	}

	return vectorizer, nil
}
// getClassVectorizer looks up className in the schema obtained from the
// provider's schema getter and returns the class's vectorizer name together
// with its vector index config. An error is returned when the schema has no
// class with that name.
func (p *Provider) getClassVectorizer(className string) (string, interface{}, error) {
	currentSchema := p.schemaGetter.GetSchemaSkipAuth()

	found := currentSchema.FindClassByName(schema.ClassName(className))
	if found == nil {
		// Callers are expected to have validated the class already, so this
		// is purely a defensive guard.
		return "", nil, fmt.Errorf("class %s not present", className)
	}

	return found.Vectorizer, found.VectorIndexConfig, nil
}