Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package modules | |
import ( | |
"context" | |
"fmt" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/tailor-inc/graphql" | |
"github.com/tailor-inc/graphql/language/ast" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
"github.com/weaviate/weaviate/entities/schema" | |
"github.com/weaviate/weaviate/entities/search" | |
) | |
var ( | |
internalSearchers = []string{ | |
"nearObject", "nearVector", "where", "group", "limit", "offset", | |
"after", "groupBy", "bm25", "hybrid", | |
} | |
internalAdditionalProperties = []string{"classification", "certainty", "id", "distance", "group"} | |
) | |
type Provider struct { | |
registered map[string]modulecapabilities.Module | |
altNames map[string]string | |
schemaGetter schemaGetter | |
hasMultipleVectorizers bool | |
} | |
type schemaGetter interface { | |
GetSchemaSkipAuth() schema.Schema | |
} | |
func NewProvider() *Provider { | |
return &Provider{ | |
registered: map[string]modulecapabilities.Module{}, | |
altNames: map[string]string{}, | |
} | |
} | |
func (p *Provider) Register(mod modulecapabilities.Module) { | |
p.registered[mod.Name()] = mod | |
if modHasAltNames, ok := mod.(modulecapabilities.ModuleHasAltNames); ok { | |
for _, altName := range modHasAltNames.AltNames() { | |
p.altNames[altName] = mod.Name() | |
} | |
} | |
} | |
func (p *Provider) GetByName(name string) modulecapabilities.Module { | |
if mod, ok := p.registered[name]; ok { | |
return mod | |
} | |
if origName, ok := p.altNames[name]; ok { | |
return p.registered[origName] | |
} | |
return nil | |
} | |
func (p *Provider) GetAll() []modulecapabilities.Module { | |
out := make([]modulecapabilities.Module, len(p.registered)) | |
i := 0 | |
for _, mod := range p.registered { | |
out[i] = mod | |
i++ | |
} | |
return out | |
} | |
func (p *Provider) GetAllExclude(module string) []modulecapabilities.Module { | |
filtered := []modulecapabilities.Module{} | |
for _, mod := range p.GetAll() { | |
if mod.Name() != module { | |
filtered = append(filtered, mod) | |
} | |
} | |
return filtered | |
} | |
func (p *Provider) SetSchemaGetter(sg schemaGetter) { | |
p.schemaGetter = sg | |
} | |
func (p *Provider) Init(ctx context.Context, | |
params moduletools.ModuleInitParams, logger logrus.FieldLogger, | |
) error { | |
for i, mod := range p.GetAll() { | |
if err := mod.Init(ctx, params); err != nil { | |
return errors.Wrapf(err, "init module %d (%q)", i, mod.Name()) | |
} else { | |
logger.WithField("action", "startup"). | |
WithField("module", mod.Name()). | |
Debug("initialized module") | |
} | |
} | |
for i, mod := range p.GetAll() { | |
if modExtension, ok := mod.(modulecapabilities.ModuleExtension); ok { | |
if err := modExtension.InitExtension(p.GetAllExclude(mod.Name())); err != nil { | |
return errors.Wrapf(err, "init module extension %d (%q)", i, mod.Name()) | |
} else { | |
logger.WithField("action", "startup"). | |
WithField("module", mod.Name()). | |
Debug("initialized module extension") | |
} | |
} | |
} | |
for i, mod := range p.GetAll() { | |
if modDependency, ok := mod.(modulecapabilities.ModuleDependency); ok { | |
if err := modDependency.InitDependency(p.GetAllExclude(mod.Name())); err != nil { | |
return errors.Wrapf(err, "init module dependency %d (%q)", i, mod.Name()) | |
} else { | |
logger.WithField("action", "startup"). | |
WithField("module", mod.Name()). | |
Debug("initialized module dependency") | |
} | |
} | |
} | |
if err := p.validate(); err != nil { | |
return errors.Wrap(err, "validate modules") | |
} | |
if p.HasMultipleVectorizers() { | |
logger.Warn("Multiple vector spaces are present, GraphQL Explore and REST API list objects endpoint module include params has been disabled as a result.") | |
} | |
return nil | |
} | |
func (p *Provider) validate() error { | |
searchers := map[string][]string{} | |
additionalGraphQLProps := map[string][]string{} | |
additionalRestAPIProps := map[string][]string{} | |
for _, mod := range p.GetAll() { | |
if module, ok := mod.(modulecapabilities.GraphQLArguments); ok { | |
allArguments := []string{} | |
for paraName, argument := range module.Arguments() { | |
if argument.ExtractFunction != nil { | |
allArguments = append(allArguments, paraName) | |
} | |
} | |
searchers = p.scanProperties(searchers, allArguments, mod.Name()) | |
} | |
if module, ok := mod.(modulecapabilities.AdditionalProperties); ok { | |
allAdditionalRestAPIProps, allAdditionalGrapQLProps := p.getAdditionalProps(module.AdditionalProperties()) | |
additionalGraphQLProps = p.scanProperties(additionalGraphQLProps, | |
allAdditionalGrapQLProps, mod.Name()) | |
additionalRestAPIProps = p.scanProperties(additionalRestAPIProps, | |
allAdditionalRestAPIProps, mod.Name()) | |
} | |
} | |
var errorMessages []string | |
errorMessages = append(errorMessages, | |
p.validateModules("searcher", searchers, internalSearchers)...) | |
errorMessages = append(errorMessages, | |
p.validateModules("graphql additional property", additionalGraphQLProps, internalAdditionalProperties)...) | |
errorMessages = append(errorMessages, | |
p.validateModules("rest api additional property", additionalRestAPIProps, internalAdditionalProperties)...) | |
if len(errorMessages) > 0 { | |
return errors.Errorf("%v", errorMessages) | |
} | |
return nil | |
} | |
func (p *Provider) scanProperties(result map[string][]string, properties []string, module string) map[string][]string { | |
for i := range properties { | |
if result[properties[i]] == nil { | |
result[properties[i]] = []string{} | |
} | |
modules := result[properties[i]] | |
modules = append(modules, module) | |
result[properties[i]] = modules | |
} | |
return result | |
} | |
func (p *Provider) getAdditionalProps(additionalProps map[string]modulecapabilities.AdditionalProperty) ([]string, []string) { | |
restProps := []string{} | |
graphQLProps := []string{} | |
for _, additionalProperty := range additionalProps { | |
if additionalProperty.RestNames != nil { | |
restProps = append(restProps, additionalProperty.RestNames...) | |
} | |
if additionalProperty.GraphQLNames != nil { | |
graphQLProps = append(graphQLProps, additionalProperty.GraphQLNames...) | |
} | |
} | |
return restProps, graphQLProps | |
} | |
func (p *Provider) validateModules(name string, properties map[string][]string, internalProperties []string) []string { | |
errorMessages := []string{} | |
for propertyName, modules := range properties { | |
for i := range internalProperties { | |
if internalProperties[i] == propertyName { | |
errorMessages = append(errorMessages, | |
fmt.Sprintf("%s: %s conflicts with weaviate's internal searcher in modules: %v", | |
name, propertyName, modules)) | |
} | |
} | |
if len(modules) > 1 { | |
p.hasMultipleVectorizers = true | |
} | |
for _, moduleName := range modules { | |
moduleType := p.GetByName(moduleName).Type() | |
if p.moduleProvidesMultipleVectorizers(moduleType) { | |
p.hasMultipleVectorizers = true | |
} | |
} | |
} | |
return errorMessages | |
} | |
func (p *Provider) moduleProvidesMultipleVectorizers(moduleType modulecapabilities.ModuleType) bool { | |
return moduleType == modulecapabilities.Text2MultiVec | |
} | |
func (p *Provider) isOnlyOneModuleEnabledOfAGivenType(moduleType modulecapabilities.ModuleType) bool { | |
i := 0 | |
for _, mod := range p.registered { | |
if mod.Type() == moduleType { | |
i++ | |
} | |
} | |
return i == 1 | |
} | |
func (p *Provider) isVectorizerModule(moduleType modulecapabilities.ModuleType) bool { | |
switch moduleType { | |
case modulecapabilities.Text2Vec, | |
modulecapabilities.Img2Vec, | |
modulecapabilities.Multi2Vec, | |
modulecapabilities.Text2MultiVec, | |
modulecapabilities.Ref2Vec: | |
return true | |
default: | |
return false | |
} | |
} | |
func (p *Provider) shouldIncludeClassArgument(class *models.Class, module string, | |
moduleType modulecapabilities.ModuleType, | |
) bool { | |
if p.isVectorizerModule(moduleType) { | |
return class.Vectorizer == module | |
} | |
if moduleConfig, ok := class.ModuleConfig.(map[string]interface{}); ok { | |
existsConfigForModule := moduleConfig[module] != nil | |
if existsConfigForModule { | |
return true | |
} | |
} | |
// Allow Text2Text (Generative, QnA, Summarize, NER) modules to be registered to a given class | |
// only if there's no configuration present and there's only one module of a given type enabled | |
return p.isOnlyOneModuleEnabledOfAGivenType(moduleType) | |
} | |
func (p *Provider) shouldCrossClassIncludeClassArgument(class *models.Class, module string, | |
moduleType modulecapabilities.ModuleType, | |
) bool { | |
if class == nil { | |
return !p.HasMultipleVectorizers() | |
} | |
return p.shouldIncludeClassArgument(class, module, moduleType) | |
} | |
func (p *Provider) shouldIncludeArgument(schema *models.Schema, module string, | |
moduleType modulecapabilities.ModuleType, | |
) bool { | |
for _, c := range schema.Classes { | |
if p.shouldIncludeClassArgument(c, module, moduleType) { | |
return true | |
} | |
} | |
return false | |
} | |
// GetArguments provides GraphQL Get arguments | |
func (p *Provider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig { | |
arguments := map[string]*graphql.ArgumentConfig{} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { | |
for name, argument := range arg.Arguments() { | |
if argument.GetArgumentsFunction != nil { | |
arguments[name] = argument.GetArgumentsFunction(class.Class) | |
} | |
} | |
} | |
} | |
} | |
return arguments | |
} | |
// AggregateArguments provides GraphQL Aggregate arguments | |
func (p *Provider) AggregateArguments(class *models.Class) map[string]*graphql.ArgumentConfig { | |
arguments := map[string]*graphql.ArgumentConfig{} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { | |
for name, argument := range arg.Arguments() { | |
if argument.AggregateArgumentsFunction != nil { | |
arguments[name] = argument.AggregateArgumentsFunction(class.Class) | |
} | |
} | |
} | |
} | |
} | |
return arguments | |
} | |
// ExploreArguments provides GraphQL Explore arguments | |
func (p *Provider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig { | |
arguments := map[string]*graphql.ArgumentConfig{} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeArgument(schema, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { | |
for name, argument := range arg.Arguments() { | |
if argument.ExploreArgumentsFunction != nil { | |
arguments[name] = argument.ExploreArgumentsFunction() | |
} | |
} | |
} | |
} | |
} | |
return arguments | |
} | |
// CrossClassExtractSearchParams extracts GraphQL arguments from modules without | |
// being specific to any one class and it's configuration. This is used in | |
// Explore() { } for example | |
func (p *Provider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} { | |
return p.extractSearchParams(arguments, nil) | |
} | |
// ExtractSearchParams extracts GraphQL arguments | |
func (p *Provider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} { | |
exractedParams := map[string]interface{}{} | |
class, err := p.getClass(className) | |
if err != nil { | |
return exractedParams | |
} | |
return p.extractSearchParams(arguments, class) | |
} | |
func (p *Provider) extractSearchParams(arguments map[string]interface{}, class *models.Class) map[string]interface{} { | |
exractedParams := map[string]interface{}{} | |
for _, module := range p.GetAll() { | |
if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { | |
if args, ok := module.(modulecapabilities.GraphQLArguments); ok { | |
for paramName, argument := range args.Arguments() { | |
if param, ok := arguments[paramName]; ok && argument.ExtractFunction != nil { | |
extracted := argument.ExtractFunction(param.(map[string]interface{})) | |
exractedParams[paramName] = extracted | |
} | |
} | |
} | |
} | |
} | |
return exractedParams | |
} | |
// CrossClassValidateSearchParam validates module parameters without | |
// being specific to any one class and it's configuration. This is used in | |
// Explore() { } for example | |
func (p *Provider) CrossClassValidateSearchParam(name string, value interface{}) error { | |
return p.validateSearchParam(name, value, nil) | |
} | |
// ValidateSearchParam validates module parameters | |
func (p *Provider) ValidateSearchParam(name string, value interface{}, className string) error { | |
class, err := p.getClass(className) | |
if err != nil { | |
return err | |
} | |
return p.validateSearchParam(name, value, class) | |
} | |
func (p *Provider) validateSearchParam(name string, value interface{}, class *models.Class) error { | |
for _, module := range p.GetAll() { | |
if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { | |
if args, ok := module.(modulecapabilities.GraphQLArguments); ok { | |
for paramName, argument := range args.Arguments() { | |
if paramName == name && argument.ValidateFunction != nil { | |
return argument.ValidateFunction(value) | |
} | |
} | |
} | |
} | |
} | |
panic("ValidateParam was called without any known params present") | |
} | |
// GetAdditionalFields provides GraphQL Get additional fields | |
func (p *Provider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field { | |
additionalProperties := map[string]*graphql.Field{} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { | |
for name, additionalProperty := range arg.AdditionalProperties() { | |
if additionalProperty.GraphQLFieldFunction != nil { | |
additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class) | |
} | |
} | |
} | |
} | |
} | |
return additionalProperties | |
} | |
// ExtractAdditionalField extracts additional properties from given graphql arguments | |
func (p *Provider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} { | |
class, err := p.getClass(className) | |
if err != nil { | |
return err | |
} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { | |
if additionalProperties := arg.AdditionalProperties(); len(additionalProperties) > 0 { | |
if additionalProperty, ok := additionalProperties[name]; ok { | |
if additionalProperty.GraphQLExtractFunction != nil { | |
return additionalProperty.GraphQLExtractFunction(params) | |
} | |
} | |
} | |
} | |
} | |
} | |
return nil | |
} | |
// GetObjectAdditionalExtend extends rest api get queries with additional properties | |
func (p *Provider) GetObjectAdditionalExtend(ctx context.Context, | |
in *search.Result, moduleParams map[string]interface{}, | |
) (*search.Result, error) { | |
resArray, err := p.additionalExtend(ctx, search.Results{*in}, moduleParams, nil, "ObjectGet", nil) | |
if err != nil { | |
return nil, err | |
} | |
return &resArray[0], nil | |
} | |
// ListObjectsAdditionalExtend extends rest api list queries with additional properties | |
func (p *Provider) ListObjectsAdditionalExtend(ctx context.Context, | |
in search.Results, moduleParams map[string]interface{}, | |
) (search.Results, error) { | |
return p.additionalExtend(ctx, in, moduleParams, nil, "ObjectList", nil) | |
} | |
// GetExploreAdditionalExtend extends graphql api get queries with additional properties | |
func (p *Provider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result, | |
moduleParams map[string]interface{}, searchVector []float32, | |
argumentModuleParams map[string]interface{}, | |
) ([]search.Result, error) { | |
return p.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams) | |
} | |
// ListExploreAdditionalExtend extends graphql api list queries with additional properties | |
func (p *Provider) ListExploreAdditionalExtend(ctx context.Context, in []search.Result, | |
moduleParams map[string]interface{}, | |
argumentModuleParams map[string]interface{}, | |
) ([]search.Result, error) { | |
return p.additionalExtend(ctx, in, moduleParams, nil, "ExploreList", argumentModuleParams) | |
} | |
func (p *Provider) additionalExtend(ctx context.Context, in []search.Result, | |
moduleParams map[string]interface{}, searchVector []float32, | |
capability string, argumentModuleParams map[string]interface{}, | |
) ([]search.Result, error) { | |
toBeExtended := in | |
if len(toBeExtended) > 0 { | |
class, err := p.getClassFromSearchResult(toBeExtended) | |
if err != nil { | |
return nil, err | |
} | |
allAdditionalProperties := map[string]modulecapabilities.AdditionalProperty{} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { | |
if arg != nil && arg.AdditionalProperties() != nil { | |
for name, additionalProperty := range arg.AdditionalProperties() { | |
allAdditionalProperties[name] = additionalProperty | |
} | |
} | |
} | |
} | |
} | |
if len(allAdditionalProperties) > 0 { | |
if err := p.checkCapabilities(allAdditionalProperties, moduleParams, capability); err != nil { | |
return nil, err | |
} | |
cfg := NewClassBasedModuleConfig(class, "", "") | |
for name, value := range moduleParams { | |
additionalPropertyFn := p.getAdditionalPropertyFn(allAdditionalProperties[name], capability) | |
if additionalPropertyFn != nil && value != nil { | |
searchValue := value | |
if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok { | |
searchVectorValue.SetSearchVector(searchVector) | |
searchValue = searchVectorValue | |
} | |
resArray, err := additionalPropertyFn(ctx, toBeExtended, searchValue, nil, argumentModuleParams, cfg) | |
if err != nil { | |
return nil, errors.Errorf("extend %s: %v", name, err) | |
} | |
toBeExtended = resArray | |
} else { | |
return nil, errors.Errorf("unknown capability: %s", name) | |
} | |
} | |
} | |
} | |
return toBeExtended, nil | |
} | |
func (p *Provider) getClassFromSearchResult(in []search.Result) (*models.Class, error) { | |
if len(in) > 0 { | |
return p.getClass(in[0].ClassName) | |
} | |
return nil, errors.Errorf("unknown class") | |
} | |
func (p *Provider) checkCapabilities(additionalProperties map[string]modulecapabilities.AdditionalProperty, | |
moduleParams map[string]interface{}, capability string, | |
) error { | |
for name := range moduleParams { | |
additionalPropertyFn := p.getAdditionalPropertyFn(additionalProperties[name], capability) | |
if additionalPropertyFn == nil { | |
return errors.Errorf("unknown capability: %s", name) | |
} | |
} | |
return nil | |
} | |
func (p *Provider) getAdditionalPropertyFn( | |
additionalProperty modulecapabilities.AdditionalProperty, | |
capability string, | |
) modulecapabilities.AdditionalPropertyFn { | |
switch capability { | |
case "ObjectGet": | |
return additionalProperty.SearchFunctions.ObjectGet | |
case "ObjectList": | |
return additionalProperty.SearchFunctions.ObjectList | |
case "ExploreGet": | |
return additionalProperty.SearchFunctions.ExploreGet | |
case "ExploreList": | |
return additionalProperty.SearchFunctions.ExploreList | |
default: | |
return nil | |
} | |
} | |
// GraphQLAdditionalFieldNames get's all additional field names used in graphql | |
func (p *Provider) GraphQLAdditionalFieldNames() []string { | |
additionalPropertiesNames := []string{} | |
for _, module := range p.GetAll() { | |
if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { | |
for _, additionalProperty := range arg.AdditionalProperties() { | |
if additionalProperty.GraphQLNames != nil { | |
additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...) | |
} | |
} | |
} | |
} | |
return additionalPropertiesNames | |
} | |
// RestApiAdditionalProperties get's all rest specific additional properties with their | |
// default values | |
func (p *Provider) RestApiAdditionalProperties(includeProp string, class *models.Class) map[string]interface{} { | |
moduleParams := map[string]interface{}{} | |
for _, module := range p.GetAll() { | |
if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { | |
if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { | |
for name, additionalProperty := range arg.AdditionalProperties() { | |
for _, includePropName := range additionalProperty.RestNames { | |
if includePropName == includeProp && moduleParams[name] == nil { | |
moduleParams[name] = additionalProperty.DefaultValue | |
} | |
} | |
} | |
} | |
} | |
} | |
return moduleParams | |
} | |
// VectorFromSearchParam gets a vector for a given argument. This is used in | |
// Get { Class() } for example | |
func (p *Provider) VectorFromSearchParam(ctx context.Context, | |
className string, param string, params interface{}, | |
findVectorFn modulecapabilities.FindVectorFn, tenant string, | |
) ([]float32, error) { | |
class, err := p.getClass(className) | |
if err != nil { | |
return nil, err | |
} | |
for _, mod := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) { | |
var moduleName string | |
var vectorSearches modulecapabilities.ArgumentVectorForParams | |
if searcher, ok := mod.(modulecapabilities.Searcher); ok { | |
moduleName = mod.Name() | |
vectorSearches = searcher.VectorSearches() | |
} else if searchers, ok := mod.(modulecapabilities.DependencySearcher); ok { | |
if dependencySearchers := searchers.VectorSearches(); dependencySearchers != nil { | |
moduleName = class.Vectorizer | |
vectorSearches = dependencySearchers[class.Vectorizer] | |
} | |
} | |
if vectorSearches != nil { | |
if searchVectorFn := vectorSearches[param]; searchVectorFn != nil { | |
cfg := NewClassBasedModuleConfig(class, moduleName, tenant) | |
vector, err := searchVectorFn(ctx, params, class.Class, findVectorFn, cfg) | |
if err != nil { | |
return nil, errors.Errorf("vectorize params: %v", err) | |
} | |
return vector, nil | |
} | |
} | |
} | |
} | |
panic("VectorFromParams was called without any known params present") | |
} | |
// CrossClassVectorFromSearchParam gets a vector for a given argument without | |
// being specific to any one class and it's configuration. This is used in | |
// Explore() { } for example | |
func (p *Provider) CrossClassVectorFromSearchParam(ctx context.Context, | |
param string, params interface{}, | |
findVectorFn modulecapabilities.FindVectorFn, | |
) ([]float32, error) { | |
for _, mod := range p.GetAll() { | |
if searcher, ok := mod.(modulecapabilities.Searcher); ok { | |
if vectorSearches := searcher.VectorSearches(); vectorSearches != nil { | |
if searchVectorFn := vectorSearches[param]; searchVectorFn != nil { | |
cfg := NewCrossClassModuleConfig() | |
vector, err := searchVectorFn(ctx, params, "", findVectorFn, cfg) | |
if err != nil { | |
return nil, errors.Errorf("vectorize params: %v", err) | |
} | |
return vector, nil | |
} | |
} | |
} | |
} | |
panic("VectorFromParams was called without any known params present") | |
} | |
func (p *Provider) VectorFromInput(ctx context.Context, | |
className string, input string, | |
) ([]float32, error) { | |
class, err := p.getClass(className) | |
if err != nil { | |
return nil, err | |
} | |
for _, mod := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) { | |
if vectorizer, ok := mod.(modulecapabilities.InputVectorizer); ok { | |
// does not access any objects, therefore tenant is irrelevant | |
cfg := NewClassBasedModuleConfig(class, mod.Name(), "") | |
return vectorizer.VectorizeInput(ctx, input, cfg) | |
} | |
} | |
} | |
return nil, fmt.Errorf("VectorFromInput was called without vectorizer") | |
} | |
// ParseClassifierSettings parses and adds classifier specific settings | |
func (p *Provider) ParseClassifierSettings(name string, | |
params *models.Classification, | |
) error { | |
class, err := p.getClass(params.Class) | |
if err != nil { | |
return err | |
} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if c, ok := module.(modulecapabilities.ClassificationProvider); ok { | |
for _, classifier := range c.Classifiers() { | |
if classifier != nil && classifier.Name() == name { | |
return classifier.ParseClassifierSettings(params) | |
} | |
} | |
} | |
} | |
} | |
return nil | |
} | |
// GetClassificationFn returns given module's classification | |
func (p *Provider) GetClassificationFn(className, name string, | |
params modulecapabilities.ClassifyParams, | |
) (modulecapabilities.ClassifyItemFn, error) { | |
class, err := p.getClass(className) | |
if err != nil { | |
return nil, err | |
} | |
for _, module := range p.GetAll() { | |
if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { | |
if c, ok := module.(modulecapabilities.ClassificationProvider); ok { | |
for _, classifier := range c.Classifiers() { | |
if classifier != nil && classifier.Name() == name { | |
return classifier.ClassifyFn(params) | |
} | |
} | |
} | |
} | |
} | |
return nil, errors.Errorf("classifier %s not found", name) | |
} | |
// GetMeta returns meta information about modules | |
func (p *Provider) GetMeta() (map[string]interface{}, error) { | |
metaInfos := map[string]interface{}{} | |
for _, module := range p.GetAll() { | |
if c, ok := module.(modulecapabilities.MetaProvider); ok { | |
meta, err := c.MetaInfo() | |
if err != nil { | |
return nil, err | |
} | |
metaInfos[module.Name()] = meta | |
} | |
} | |
return metaInfos, nil | |
} | |
func (p *Provider) getClass(className string) (*models.Class, error) { | |
sch := p.schemaGetter.GetSchemaSkipAuth() | |
class := sch.FindClassByName(schema.ClassName(className)) | |
if class == nil { | |
return nil, errors.Errorf("class %q not found in schema", className) | |
} | |
return class, nil | |
} | |
func (p *Provider) HasMultipleVectorizers() bool { | |
return p.hasMultipleVectorizers | |
} | |
func (p *Provider) BackupBackend(backend string) (modulecapabilities.BackupBackend, error) { | |
if module := p.GetByName(backend); module != nil { | |
if module.Type() == modulecapabilities.Backup { | |
if backend, ok := module.(modulecapabilities.BackupBackend); ok { | |
return backend, nil | |
} | |
} | |
} | |
return nil, errors.Errorf("backup: %s not found", backend) | |
} | |