//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package modules

import (
	"context"
	"fmt"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/tailor-inc/graphql"
	"github.com/tailor-inc/graphql/language/ast"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/moduletools"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/search"
)

var (
	internalSearchers = []string{
		"nearObject", "nearVector", "where", "group", "limit", "offset",
		"after", "groupBy", "bm25", "hybrid",
	}
	internalAdditionalProperties = []string{"classification", "certainty", "id", "distance", "group"}
)

// Provider holds all registered modules and resolves them by name or
// alternative name.
type Provider struct {
	registered             map[string]modulecapabilities.Module
	altNames               map[string]string
	schemaGetter           schemaGetter
	hasMultipleVectorizers bool
}

type schemaGetter interface {
	GetSchemaSkipAuth() schema.Schema
}

// NewProvider returns an empty Provider with no modules registered.
func NewProvider() *Provider {
	return &Provider{
		registered: map[string]modulecapabilities.Module{},
		altNames:   map[string]string{},
	}
}

// Register adds the module under its name and records any alternative names
// it declares.
func (p *Provider) Register(mod modulecapabilities.Module) {
	p.registered[mod.Name()] = mod
	if modHasAltNames, ok := mod.(modulecapabilities.ModuleHasAltNames); ok {
		for _, altName := range modHasAltNames.AltNames() {
			p.altNames[altName] = mod.Name()
		}
	}
}

// GetByName returns the module registered under name or under one of its
// alternative names, or nil if there is none.
func (p *Provider) GetByName(name string) modulecapabilities.Module {
	if mod, ok := p.registered[name]; ok {
		return mod
	}
	if origName, ok := p.altNames[name]; ok {
		return p.registered[origName]
	}
	return nil
}

// GetAll returns all registered modules. Because it iterates over a map, the
// order of the result is not deterministic.
func (p *Provider) GetAll() []modulecapabilities.Module {
	out := make([]modulecapabilities.Module, len(p.registered))
	i := 0
	for _, mod := range p.registered {
		out[i] = mod
		i++
	}

	return out
}

// GetAllExclude returns all registered modules except the one with the given
// name.
func (p *Provider) GetAllExclude(module string) []modulecapabilities.Module {
	filtered := []modulecapabilities.Module{}
	for _, mod := range p.GetAll() {
		if mod.Name() != module {
			filtered = append(filtered, mod)
		}
	}
	return filtered
}

func (p *Provider) SetSchemaGetter(sg schemaGetter) {
	p.schemaGetter = sg
}

// Init initializes every module, then module extensions, then module
// dependencies, and finally validates the set of registered modules.
func (p *Provider) Init(ctx context.Context,
	params moduletools.ModuleInitParams, logger logrus.FieldLogger,
) error {
	for i, mod := range p.GetAll() {
		if err := mod.Init(ctx, params); err != nil {
			return errors.Wrapf(err, "init module %d (%q)", i, mod.Name())
		} else {
			logger.WithField("action", "startup").
				WithField("module", mod.Name()).
				Debug("initialized module")
		}
	}
	for i, mod := range p.GetAll() {
		if modExtension, ok := mod.(modulecapabilities.ModuleExtension); ok {
			if err := modExtension.InitExtension(p.GetAllExclude(mod.Name())); err != nil {
				return errors.Wrapf(err, "init module extension %d (%q)", i, mod.Name())
			} else {
				logger.WithField("action", "startup").
					WithField("module", mod.Name()).
					Debug("initialized module extension")
			}
		}
	}
	for i, mod := range p.GetAll() {
		if modDependency, ok := mod.(modulecapabilities.ModuleDependency); ok {
			if err := modDependency.InitDependency(p.GetAllExclude(mod.Name())); err != nil {
				return errors.Wrapf(err, "init module dependency %d (%q)", i, mod.Name())
			} else {
				logger.WithField("action", "startup").
					WithField("module", mod.Name()).
					Debug("initialized module dependency")
			}
		}
	}
	if err := p.validate(); err != nil {
		return errors.Wrap(err, "validate modules")
	}
	if p.HasMultipleVectorizers() {
		logger.Warn("Multiple vector spaces are present, GraphQL Explore and REST API list objects endpoint module include params has been disabled as a result.")
	}
	return nil
}

func (p *Provider) validate() error {
	searchers := map[string][]string{}
	additionalGraphQLProps := map[string][]string{}
	additionalRestAPIProps := map[string][]string{}
	for _, mod := range p.GetAll() {
		if module, ok := mod.(modulecapabilities.GraphQLArguments); ok {
			allArguments := []string{}
			for paraName, argument := range module.Arguments() {
				if argument.ExtractFunction != nil {
					allArguments = append(allArguments, paraName)
				}
			}
			searchers = p.scanProperties(searchers, allArguments, mod.Name())
		}
		if module, ok := mod.(modulecapabilities.AdditionalProperties); ok {
			allAdditionalRestAPIProps, allAdditionalGraphQLProps := p.getAdditionalProps(module.AdditionalProperties())
			additionalGraphQLProps = p.scanProperties(additionalGraphQLProps,
				allAdditionalGraphQLProps, mod.Name())
			additionalRestAPIProps = p.scanProperties(additionalRestAPIProps,
				allAdditionalRestAPIProps, mod.Name())
		}
	}

	var errorMessages []string
	errorMessages = append(errorMessages,
		p.validateModules("searcher", searchers, internalSearchers)...)
	errorMessages = append(errorMessages,
		p.validateModules("graphql additional property", additionalGraphQLProps, internalAdditionalProperties)...)
	errorMessages = append(errorMessages,
		p.validateModules("rest api additional property", additionalRestAPIProps, internalAdditionalProperties)...)
	if len(errorMessages) > 0 {
		return errors.Errorf("%v", errorMessages)
	}

	return nil
}

func (p *Provider) scanProperties(result map[string][]string, properties []string, module string) map[string][]string {
	for i := range properties {
		if result[properties[i]] == nil {
			result[properties[i]] = []string{}
		}
		modules := result[properties[i]]
		modules = append(modules, module)
		result[properties[i]] = modules
	}
	return result
}

func (p *Provider) getAdditionalProps(additionalProps map[string]modulecapabilities.AdditionalProperty) ([]string, []string) {
	restProps := []string{}
	graphQLProps := []string{}

	for _, additionalProperty := range additionalProps {
		if additionalProperty.RestNames != nil {
			restProps = append(restProps, additionalProperty.RestNames...)
		}
		if additionalProperty.GraphQLNames != nil {
			graphQLProps = append(graphQLProps, additionalProperty.GraphQLNames...)
		}
	}
	return restProps, graphQLProps
}

func (p *Provider) validateModules(name string, properties map[string][]string, internalProperties []string) []string {
	errorMessages := []string{}
	for propertyName, modules := range properties {
		for i := range internalProperties {
			if internalProperties[i] == propertyName {
				errorMessages = append(errorMessages,
					fmt.Sprintf("%s: %s conflicts with weaviate's internal searcher in modules: %v",
						name, propertyName, modules))
			}
		}
		if len(modules) > 1 {
			p.hasMultipleVectorizers = true
		}
		for _, moduleName := range modules {
			moduleType := p.GetByName(moduleName).Type()
			if p.moduleProvidesMultipleVectorizers(moduleType) {
				p.hasMultipleVectorizers = true
			}
		}
	}
	return errorMessages
}

func (p *Provider) moduleProvidesMultipleVectorizers(moduleType modulecapabilities.ModuleType) bool {
	return moduleType == modulecapabilities.Text2MultiVec
}

func (p *Provider) isOnlyOneModuleEnabledOfAGivenType(moduleType modulecapabilities.ModuleType) bool {
	i := 0
	for _, mod := range p.registered {
		if mod.Type() == moduleType {
			i++
		}
	}
	return i == 1
}

func (p *Provider) isVectorizerModule(moduleType modulecapabilities.ModuleType) bool {
	switch moduleType {
	case modulecapabilities.Text2Vec,
		modulecapabilities.Img2Vec,
		modulecapabilities.Multi2Vec,
		modulecapabilities.Text2MultiVec,
		modulecapabilities.Ref2Vec:
		return true
	default:
		return false
	}
}

func (p *Provider) shouldIncludeClassArgument(class *models.Class, module string,
	moduleType modulecapabilities.ModuleType,
) bool {
	if p.isVectorizerModule(moduleType) {
		return class.Vectorizer == module
	}
	if moduleConfig, ok := class.ModuleConfig.(map[string]interface{}); ok {
		existsConfigForModule := moduleConfig[module] != nil
		if existsConfigForModule {
			return true
		}
	}
	// Allow Text2Text (Generative, QnA, Summarize, NER) modules to be registered to a given class
	// only if there's no configuration present and there's only one module of a given type enabled
	return p.isOnlyOneModuleEnabledOfAGivenType(moduleType)
}

func (p *Provider) shouldCrossClassIncludeClassArgument(class *models.Class, module string,
	moduleType modulecapabilities.ModuleType,
) bool {
	if class == nil {
		return !p.HasMultipleVectorizers()
	}
	return p.shouldIncludeClassArgument(class, module, moduleType)
}

func (p *Provider) shouldIncludeArgument(schema *models.Schema, module string,
	moduleType modulecapabilities.ModuleType,
) bool {
	for _, c := range schema.Classes {
		if p.shouldIncludeClassArgument(c, module, moduleType) {
			return true
		}
	}
	return false
}

// GetArguments provides GraphQL Get arguments
func (p *Provider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
	arguments := map[string]*graphql.ArgumentConfig{}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
				for name, argument := range arg.Arguments() {
					if argument.GetArgumentsFunction != nil {
						arguments[name] = argument.GetArgumentsFunction(class.Class)
					}
				}
			}
		}
	}
	return arguments
}

// AggregateArguments provides GraphQL Aggregate arguments
func (p *Provider) AggregateArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
	arguments := map[string]*graphql.ArgumentConfig{}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
				for name, argument := range arg.Arguments() {
					if argument.AggregateArgumentsFunction != nil {
						arguments[name] = argument.AggregateArgumentsFunction(class.Class)
					}
				}
			}
		}
	}
	return arguments
}

// ExploreArguments provides GraphQL Explore arguments
func (p *Provider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig {
	arguments := map[string]*graphql.ArgumentConfig{}
	for _, module := range p.GetAll() {
		if p.shouldIncludeArgument(schema, module.Name(), module.Type()) {
			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
				for name, argument := range arg.Arguments() {
					if argument.ExploreArgumentsFunction != nil {
						arguments[name] = argument.ExploreArgumentsFunction()
					}
				}
			}
		}
	}
	return arguments
}

// CrossClassExtractSearchParams extracts GraphQL arguments from modules without
// being specific to any one class and its configuration. This is used in
// Explore() { } for example
func (p *Provider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} {
	return p.extractSearchParams(arguments, nil)
}

// ExtractSearchParams extracts GraphQL arguments
func (p *Provider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} {
	extractedParams := map[string]interface{}{}
	class, err := p.getClass(className)
	if err != nil {
		return extractedParams
	}
	return p.extractSearchParams(arguments, class)
}

func (p *Provider) extractSearchParams(arguments map[string]interface{}, class *models.Class) map[string]interface{} {
	extractedParams := map[string]interface{}{}
	for _, module := range p.GetAll() {
		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
			if args, ok := module.(modulecapabilities.GraphQLArguments); ok {
				for paramName, argument := range args.Arguments() {
					if param, ok := arguments[paramName]; ok && argument.ExtractFunction != nil {
						extracted := argument.ExtractFunction(param.(map[string]interface{}))
						extractedParams[paramName] = extracted
					}
				}
			}
		}
	}
	return extractedParams
}

// CrossClassValidateSearchParam validates module parameters without
// being specific to any one class and its configuration. This is used in
// Explore() { } for example
func (p *Provider) CrossClassValidateSearchParam(name string, value interface{}) error {
	return p.validateSearchParam(name, value, nil)
}

// ValidateSearchParam validates module parameters
func (p *Provider) ValidateSearchParam(name string, value interface{}, className string) error {
	class, err := p.getClass(className)
	if err != nil {
		return err
	}

	return p.validateSearchParam(name, value, class)
}

func (p *Provider) validateSearchParam(name string, value interface{}, class *models.Class) error {
	for _, module := range p.GetAll() {
		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
			if args, ok := module.(modulecapabilities.GraphQLArguments); ok {
				for paramName, argument := range args.Arguments() {
					if paramName == name && argument.ValidateFunction != nil {
						return argument.ValidateFunction(value)
					}
				}
			}
		}
	}

	panic("ValidateParam was called without any known params present")
}

// GetAdditionalFields provides GraphQL Get additional fields
func (p *Provider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field {
	additionalProperties := map[string]*graphql.Field{}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
				for name, additionalProperty := range arg.AdditionalProperties() {
					if additionalProperty.GraphQLFieldFunction != nil {
						additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class)
					}
				}
			}
		}
	}
	return additionalProperties
}

// ExtractAdditionalField extracts additional properties from given graphql arguments
func (p *Provider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} {
	class, err := p.getClass(className)
	if err != nil {
		return err
	}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if arg, ok :=
				module.(modulecapabilities.AdditionalProperties); ok {
				if additionalProperties := arg.AdditionalProperties(); len(additionalProperties) > 0 {
					if additionalProperty, ok := additionalProperties[name]; ok {
						if additionalProperty.GraphQLExtractFunction != nil {
							return additionalProperty.GraphQLExtractFunction(params)
						}
					}
				}
			}
		}
	}
	return nil
}

// GetObjectAdditionalExtend extends rest api get queries with additional properties
func (p *Provider) GetObjectAdditionalExtend(ctx context.Context,
	in *search.Result, moduleParams map[string]interface{},
) (*search.Result, error) {
	resArray, err := p.additionalExtend(ctx, search.Results{*in}, moduleParams, nil, "ObjectGet", nil)
	if err != nil {
		return nil, err
	}
	return &resArray[0], nil
}

// ListObjectsAdditionalExtend extends rest api list queries with additional properties
func (p *Provider) ListObjectsAdditionalExtend(ctx context.Context,
	in search.Results, moduleParams map[string]interface{},
) (search.Results, error) {
	return p.additionalExtend(ctx, in, moduleParams, nil, "ObjectList", nil)
}

// GetExploreAdditionalExtend extends graphql api get queries with additional properties
func (p *Provider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result,
	moduleParams map[string]interface{}, searchVector []float32,
	argumentModuleParams map[string]interface{},
) ([]search.Result, error) {
	return p.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams)
}

// ListExploreAdditionalExtend extends graphql api list queries with additional properties
func (p *Provider) ListExploreAdditionalExtend(ctx context.Context, in []search.Result,
	moduleParams map[string]interface{},
	argumentModuleParams map[string]interface{},
) ([]search.Result, error) {
	return p.additionalExtend(ctx, in, moduleParams, nil, "ExploreList", argumentModuleParams)
}

func (p *Provider) additionalExtend(ctx context.Context, in []search.Result,
	moduleParams map[string]interface{}, searchVector []float32,
	capability string, argumentModuleParams map[string]interface{},
) ([]search.Result, error) {
	toBeExtended := in
	if len(toBeExtended) > 0 {
		class, err := p.getClassFromSearchResult(toBeExtended)
		if err != nil {
			return nil, err
		}

		allAdditionalProperties := map[string]modulecapabilities.AdditionalProperty{}
		for _, module := range p.GetAll() {
			if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
				if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
					if arg != nil && arg.AdditionalProperties() != nil {
						for name, additionalProperty := range arg.AdditionalProperties() {
							allAdditionalProperties[name] = additionalProperty
						}
					}
				}
			}
		}
		if len(allAdditionalProperties) > 0 {
			if err := p.checkCapabilities(allAdditionalProperties, moduleParams, capability); err != nil {
				return nil, err
			}
			cfg := NewClassBasedModuleConfig(class, "", "")
			for name, value := range moduleParams {
				additionalPropertyFn := p.getAdditionalPropertyFn(allAdditionalProperties[name], capability)
				if additionalPropertyFn != nil && value != nil {
					searchValue := value
					if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok {
						searchVectorValue.SetSearchVector(searchVector)
						searchValue = searchVectorValue
					}
					resArray, err := additionalPropertyFn(ctx, toBeExtended, searchValue, nil, argumentModuleParams, cfg)
					if err != nil {
						return nil, errors.Errorf("extend %s: %v", name, err)
					}
					toBeExtended = resArray
				} else {
					return nil, errors.Errorf("unknown capability: %s", name)
				}
			}
		}
	}
	return toBeExtended, nil
}

func (p *Provider) getClassFromSearchResult(in []search.Result) (*models.Class, error) {
	if len(in) > 0 {
		return p.getClass(in[0].ClassName)
	}
	return nil, errors.Errorf("unknown class")
}

func (p *Provider) checkCapabilities(additionalProperties map[string]modulecapabilities.AdditionalProperty,
	moduleParams map[string]interface{}, capability string,
) error {
	for name := range moduleParams {
		additionalPropertyFn :=
			p.getAdditionalPropertyFn(additionalProperties[name], capability)
		if additionalPropertyFn == nil {
			return errors.Errorf("unknown capability: %s", name)
		}
	}
	return nil
}

func (p *Provider) getAdditionalPropertyFn(
	additionalProperty modulecapabilities.AdditionalProperty,
	capability string,
) modulecapabilities.AdditionalPropertyFn {
	switch capability {
	case "ObjectGet":
		return additionalProperty.SearchFunctions.ObjectGet
	case "ObjectList":
		return additionalProperty.SearchFunctions.ObjectList
	case "ExploreGet":
		return additionalProperty.SearchFunctions.ExploreGet
	case "ExploreList":
		return additionalProperty.SearchFunctions.ExploreList
	default:
		return nil
	}
}

// GraphQLAdditionalFieldNames gets all additional field names used in graphql
func (p *Provider) GraphQLAdditionalFieldNames() []string {
	additionalPropertiesNames := []string{}
	for _, module := range p.GetAll() {
		if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
			for _, additionalProperty := range arg.AdditionalProperties() {
				if additionalProperty.GraphQLNames != nil {
					additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...)
				}
			}
		}
	}
	return additionalPropertiesNames
}

// RestApiAdditionalProperties gets all rest specific additional properties with their
// default values
func (p *Provider) RestApiAdditionalProperties(includeProp string, class *models.Class) map[string]interface{} {
	moduleParams := map[string]interface{}{}
	for _, module := range p.GetAll() {
		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
			if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
				for name, additionalProperty := range arg.AdditionalProperties() {
					for _, includePropName := range additionalProperty.RestNames {
						if includePropName == includeProp && moduleParams[name] == nil {
							moduleParams[name] = additionalProperty.DefaultValue
						}
					}
				}
			}
		}
	}
	return moduleParams
}

// VectorFromSearchParam gets a vector for a given argument. This is used in
// Get { Class() } for example
func (p *Provider) VectorFromSearchParam(ctx context.Context,
	className string, param string, params interface{},
	findVectorFn modulecapabilities.FindVectorFn, tenant string,
) ([]float32, error) {
	class, err := p.getClass(className)
	if err != nil {
		return nil, err
	}

	for _, mod := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) {
			var moduleName string
			var vectorSearches modulecapabilities.ArgumentVectorForParams
			if searcher, ok := mod.(modulecapabilities.Searcher); ok {
				moduleName = mod.Name()
				vectorSearches = searcher.VectorSearches()
			} else if searchers, ok := mod.(modulecapabilities.DependencySearcher); ok {
				if dependencySearchers := searchers.VectorSearches(); dependencySearchers != nil {
					moduleName = class.Vectorizer
					vectorSearches = dependencySearchers[class.Vectorizer]
				}
			}
			if vectorSearches != nil {
				if searchVectorFn := vectorSearches[param]; searchVectorFn != nil {
					cfg := NewClassBasedModuleConfig(class, moduleName, tenant)
					vector, err := searchVectorFn(ctx, params, class.Class, findVectorFn, cfg)
					if err != nil {
						return nil, errors.Errorf("vectorize params: %v", err)
					}
					return vector, nil
				}
			}
		}
	}

	panic("VectorFromParams was called without any known params present")
}

// CrossClassVectorFromSearchParam gets a vector for a given argument without
// being specific to any one class and its configuration. This is used in
// Explore() { } for example
func (p *Provider) CrossClassVectorFromSearchParam(ctx context.Context,
	param string, params interface{},
	findVectorFn modulecapabilities.FindVectorFn,
) ([]float32, error) {
	for _, mod := range p.GetAll() {
		if searcher, ok := mod.(modulecapabilities.Searcher); ok {
			if vectorSearches := searcher.VectorSearches(); vectorSearches != nil {
				if searchVectorFn := vectorSearches[param]; searchVectorFn != nil {
					cfg := NewCrossClassModuleConfig()
					vector, err := searchVectorFn(ctx, params, "", findVectorFn, cfg)
					if err != nil {
						return nil, errors.Errorf("vectorize params: %v", err)
					}
					return vector, nil
				}
			}
		}
	}

	panic("VectorFromParams was called without any known params present")
}

// VectorFromInput vectorizes the input string with the first module enabled
// for the class that implements InputVectorizer.
func (p *Provider) VectorFromInput(ctx context.Context,
	className string, input string,
) ([]float32, error) {
	class, err := p.getClass(className)
	if err != nil {
		return nil, err
	}

	for _, mod := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) {
			if vectorizer, ok := mod.(modulecapabilities.InputVectorizer); ok {
				// does not access any objects, therefore tenant is irrelevant
				cfg := NewClassBasedModuleConfig(class, mod.Name(), "")
				return vectorizer.VectorizeInput(ctx, input, cfg)
			}
		}
	}

	return nil, fmt.Errorf("VectorFromInput was called without vectorizer")
}

// ParseClassifierSettings parses and adds classifier specific settings
func (p *Provider) ParseClassifierSettings(name string,
	params *models.Classification,
) error {
	class, err := p.getClass(params.Class)
	if err != nil {
		return err
	}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if c, ok := module.(modulecapabilities.ClassificationProvider); ok {
				for _, classifier := range c.Classifiers() {
					if classifier != nil && classifier.Name() == name {
						return classifier.ParseClassifierSettings(params)
					}
				}
			}
		}
	}
	return nil
}

// GetClassificationFn returns given module's classification
func (p *Provider) GetClassificationFn(className, name string,
	params modulecapabilities.ClassifyParams,
) (modulecapabilities.ClassifyItemFn, error) {
	class, err := p.getClass(className)
	if err != nil {
		return nil, err
	}
	for _, module := range p.GetAll() {
		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
			if c, ok := module.(modulecapabilities.ClassificationProvider); ok {
				for _, classifier := range c.Classifiers() {
					if classifier != nil && classifier.Name() == name {
						return classifier.ClassifyFn(params)
					}
				}
			}
		}
	}
	return nil, errors.Errorf("classifier %s not found", name)
}

// GetMeta returns meta information about modules
func (p *Provider) GetMeta() (map[string]interface{}, error) {
	metaInfos := map[string]interface{}{}
	for _, module := range p.GetAll() {
		if c, ok := module.(modulecapabilities.MetaProvider); ok {
			meta, err := c.MetaInfo()
			if err != nil {
				return nil, err
			}
			metaInfos[module.Name()] = meta
		}
	}
	return metaInfos, nil
}

func (p *Provider) getClass(className string) (*models.Class, error) {
	sch := p.schemaGetter.GetSchemaSkipAuth()
	class := sch.FindClassByName(schema.ClassName(className))
	if class == nil {
		return nil, errors.Errorf("class %q not found in schema", className)
	}
	return class, nil
}

// HasMultipleVectorizers reports whether module validation flagged multiple
// vector spaces.
func (p *Provider) HasMultipleVectorizers() bool {
	return p.hasMultipleVectorizers
}

// BackupBackend returns the backup module registered under the given name, or
// an error if there is none.
func (p *Provider) BackupBackend(backend string) (modulecapabilities.BackupBackend, error) {
	if module := p.GetByName(backend); module != nil {
		if module.Type() == modulecapabilities.Backup {
			if backend, ok := module.(modulecapabilities.BackupBackend); ok {
				return backend, nil
			}
		}
	}
	return nil, errors.Errorf("backup: %s not found", backend)
}