//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package classification

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/go-openapi/strfmt"
	"github.com/google/uuid"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/weaviate/weaviate/adapters/handlers/rest/filterext"
	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/dto"
	libfilters "github.com/weaviate/weaviate/entities/filters"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/usecases/objects"
	schemaUC "github.com/weaviate/weaviate/usecases/schema"
	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
)

type classificationFilters struct {
	source      *libfilters.LocalFilter
	target      *libfilters.LocalFilter
	trainingSet *libfilters.LocalFilter
}

func (f classificationFilters) Source() *libfilters.LocalFilter {
	return f.source
}

func (f classificationFilters) Target() *libfilters.LocalFilter {
	return f.target
}

func (f classificationFilters) TrainingSet() *libfilters.LocalFilter {
	return f.trainingSet
}

type distancer func(a, b []float32) (float32, error)

type Classifier struct {
	schemaGetter          schemaUC.SchemaGetter
	repo                  Repo
	vectorRepo            vectorRepo
	vectorClassSearchRepo modulecapabilities.VectorClassSearchRepo
	authorizer            authorizer
	distancer             distancer
	modulesProvider       ModulesProvider
	logger                logrus.FieldLogger
}

type authorizer interface {
	Authorize(principal *models.Principal, verb, resource string) error
}

type ModulesProvider interface {
	ParseClassifierSettings(name string,
		params *models.Classification) error
	GetClassificationFn(className, name string,
		params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error)
}

func New(sg schemaUC.SchemaGetter, cr Repo, vr vectorRepo, authorizer authorizer,
	logger logrus.FieldLogger, modulesProvider ModulesProvider,
) *Classifier {
	return &Classifier{
		logger:                logger,
		schemaGetter:          sg,
		repo:                  cr,
		vectorRepo:            vr,
		authorizer:            authorizer,
		distancer:             libvectorizer.NormalizedDistance,
		vectorClassSearchRepo: newVectorClassSearchRepo(vr),
		modulesProvider:       modulesProvider,
	}
}

// Repo to manage classification state, should be consistent, not used to store
// actual data object vectors, see VectorRepo
type Repo interface {
	Put(ctx context.Context, classification models.Classification) error
	Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error)
}

type VectorRepo interface {
	GetUnclassified(ctx context.Context, class string,
		properties []string, filter *libfilters.LocalFilter) ([]search.Result, error)
	AggregateNeighbors(ctx context.Context, vector []float32,
		class string, properties []string, k int,
		filter *libfilters.LocalFilter) ([]NeighborRef, error)
	VectorSearch(ctx context.Context, params dto.GetParams) ([]search.Result, error)
	ZeroShotSearch(ctx context.Context, vector []float32,
		class string, properties []string,
		filter *libfilters.LocalFilter) ([]search.Result, error)
}

type vectorRepo interface {
	VectorRepo
	BatchPutObjects(ctx context.Context, objects objects.BatchObjects,
		repl *additional.ReplicationProperties) (objects.BatchObjects, error)
}

// NeighborRef is the result of an aggregation of the ref properties of k
// neighbors
type NeighborRef struct {
	// Property indicates which property was aggregated
	Property string

	// The beacon of the most common (kNN) reference
	Beacon strfmt.URI

	OverallCount int
	WinningCount int
	LosingCount  int

	Distances NeighborRefDistances
}

func (c *Classifier) Schedule(ctx context.Context, principal *models.Principal,
	params models.Classification,
) (*models.Classification, error) {
	err := c.authorizer.Authorize(principal, "create", "classifications/*")
	if err != nil {
		return nil, err
	}

	err = c.parseAndSetDefaults(&params)
	if err != nil {
		return nil, err
	}

	err = NewValidator(c.schemaGetter, params).Do()
	if err != nil {
		return nil, err
	}

	if err := c.assignNewID(&params); err != nil {
		return nil, fmt.Errorf("classification: assign id: %v", err)
	}

	params.Status = models.ClassificationStatusRunning
	params.Meta = &models.ClassificationMeta{
		Started: strfmt.DateTime(time.Now()),
	}

	if err := c.repo.Put(ctx, params); err != nil {
		return nil, fmt.Errorf("classification: put: %v", err)
	}

	// asynchronously trigger the classification
	filters, err := c.extractFilters(params)
	if err != nil {
		return nil, err
	}

	go c.run(params, filters)

	return &params, nil
}

func (c *Classifier) extractFilters(params models.Classification) (Filters, error) {
	if params.Filters == nil {
		return classificationFilters{}, nil
	}

	source, err := filterext.Parse(params.Filters.SourceWhere, params.Class)
	if err != nil {
		return classificationFilters{}, fmt.Errorf("field 'sourceWhere': %v", err)
	}

	trainingSet, err := filterext.Parse(params.Filters.TrainingSetWhere, params.Class)
	if err != nil {
		return classificationFilters{}, fmt.Errorf("field 'trainingSetWhere': %v", err)
	}

	target, err := filterext.Parse(params.Filters.TargetWhere, params.Class)
	if err != nil {
		return classificationFilters{}, fmt.Errorf("field 'targetWhere': %v", err)
	}

	filters := classificationFilters{
		source:      source,
		trainingSet: trainingSet,
		target:      target,
	}

	if err = c.validateFilters(&params, &filters); err != nil {
		return nil, err
	}

	return filters, nil
}

func (c *Classifier) validateFilters(params *models.Classification,
	filters *classificationFilters,
) (err error) {
	if params.Type == TypeKNN {
		if err = c.validateFilter(filters.Source()); err != nil {
			return fmt.Errorf("invalid sourceWhere: %s", err)
		}
		if err = c.validateFilter(filters.TrainingSet()); err != nil {
			return fmt.Errorf("invalid trainingSetWhere: %s", err)
		}
	}

	if params.Type == TypeContextual || params.Type == TypeZeroShot {
		if err = c.validateFilter(filters.Source()); err != nil {
			return fmt.Errorf("invalid sourceWhere: %s", err)
		}
		if err = c.validateFilter(filters.Target()); err != nil {
			return fmt.Errorf("invalid targetWhere: %s", err)
		}
	}

	return
}

func (c *Classifier) validateFilter(filter *libfilters.LocalFilter) error {
	if filter == nil {
		return nil
	}
	return libfilters.ValidateFilters(c.schemaGetter.GetSchemaSkipAuth(), filter)
}

func (c *Classifier) assignNewID(params *models.Classification) error {
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}

	params.ID = strfmt.UUID(id.String())
	return nil
}

func (c *Classifier) Get(ctx context.Context, principal *models.Principal,
	id strfmt.UUID,
) (*models.Classification, error) {
	err := c.authorizer.Authorize(principal, "get", "classifications/*")
	if err != nil {
		return nil, err
	}

	return c.repo.Get(ctx, id)
}

func (c *Classifier) parseAndSetDefaults(params *models.Classification) error {
	if params.Type == "" {
		defaultType := "knn"
		params.Type = defaultType
	}

	if params.Type == "knn" {
		if err := c.parseKNNSettings(params); err != nil {
			return errors.Wrapf(err, "parse knn specific settings")
		}
		return nil
	}

	if c.modulesProvider != nil {
		if err := c.modulesProvider.ParseClassifierSettings(params.Type, params); err != nil {
			return errors.Wrapf(err, "parse %s specific settings", params.Type)
		}
		return nil
	}

	return nil
}

func (c *Classifier) parseKNNSettings(params *models.Classification) error {
	raw := params.Settings
	settings := &ParamsKNN{}
	if raw == nil {
		settings.SetDefaults()
		params.Settings = settings
		return nil
	}

	asMap, ok := raw.(map[string]interface{})
	if !ok {
		return errors.Errorf("settings must be an object got %T", raw)
	}

	v, err := extractNumberFromMap(asMap, "k")
	if err != nil {
		return err
	}

	settings.K = v
	settings.SetDefaults()
	params.Settings = settings

	return nil
}

type ParamsKNN struct {
	K *int32 `json:"k"`
}

func (params *ParamsKNN) SetDefaults() {
	if params.K == nil {
		defaultK := int32(3)
		params.K = &defaultK
	}
}

func extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) {
	unparsed, present := in[field]
	if present {
		parsed, ok := unparsed.(json.Number)
		if !ok {
			return nil, errors.Errorf("settings.%s must be number, got %T", field, unparsed)
		}

		asInt64, err := parsed.Int64()
		if err != nil {
			return nil, errors.Wrapf(err, "settings.%s", field)
		}

		asInt32 := int32(asInt64)
		return &asInt32, nil
	}

	return nil, nil
}
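
// Usage sketch (illustrative, not part of this package): extractNumberFromMap
// accepts settings.k only when the value is a json.Number. A standalone
// program along the following lines shows why that depends on how the request
// body was decoded. The helper decodeK is hypothetical, and the assumption
// that the caller decodes with Decoder.UseNumber() is not confirmed by this
// file. The fallback of 3 mirrors ParamsKNN.SetDefaults above.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// decodeK is a hypothetical helper that mimics the type check performed by
// extractNumberFromMap for the "k" field of a settings object.
func decodeK(body string, useNumber bool) (int32, error) {
	dec := json.NewDecoder(strings.NewReader(body))
	if useNumber {
		// Keep JSON numbers as json.Number instead of float64.
		dec.UseNumber()
	}

	var settings map[string]interface{}
	if err := dec.Decode(&settings); err != nil {
		return 0, err
	}

	raw, present := settings["k"]
	if !present {
		// Same default as ParamsKNN.SetDefaults.
		return 3, nil
	}

	n, ok := raw.(json.Number)
	if !ok {
		return 0, fmt.Errorf("settings.k must be number, got %T", raw)
	}

	v, err := n.Int64()
	if err != nil {
		return 0, fmt.Errorf("settings.k: %w", err)
	}
	return int32(v), nil
}

func main() {
	// Decoded with UseNumber: the type assertion succeeds.
	fmt.Println(decodeK(`{"k": 5}`, true))
	// Decoded without UseNumber: k arrives as float64 and is rejected.
	fmt.Println(decodeK(`{"k": 5}`, false))
	// No k at all: the default applies.
	fmt.Println(decodeK(`{}`, true))
}
```

// Note that a non-integer value such as 2.5 also fails in this sketch, because
// json.Number.Int64 returns an error for it.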