KevinStephenson
Adding in weaviate code
b110593
raw
history blame
3.44 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package classification
import (
"context"
"encoding/json"
"github.com/pkg/errors"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
)
// vectorizer is the subset of vectorization functionality this package
// requires from whatever is passed to New.
type vectorizer interface {
// MultiVectorForWord returns one vector per entry of words and must preserve
// the input order. If a word cannot be vectorized, its position must hold an
// explicit nil rather than being skipped.
MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error)
// VectorOnlyForCorpi returns a single vector for the given corpi.
// NOTE(review): the meaning of overrides is not visible in this file —
// confirm against the implementation.
VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error)
}
// Classifier is the "text2vec-contextionary-contextual" classifier. It keeps
// the vectorizer it was constructed with; ClassifyFn returns an error when
// that vectorizer is nil.
type Classifier struct {
// vectorizer is the dependency injected through New.
vectorizer vectorizer
}
// New builds a Classifier that uses the given vectorizer and returns it as a
// modulecapabilities.Classifier. A nil vectorizer is accepted at construction
// time; ClassifyFn reports an error for it later.
func New(vectorizer vectorizer) modulecapabilities.Classifier {
	classifier := new(Classifier)
	classifier.vectorizer = vectorizer
	return classifier
}
// Name returns the fixed name of this classifier.
func (c *Classifier) Name() string {
	const classifierName = "text2vec-contextionary-contextual"
	return classifierName
}
// ClassifyFn prepares a contextual classification run and returns the function
// that classifies a single item.
//
// It returns an error if the Classifier has no vectorizer, or if the one-time
// preparation step fails (the preparation error is wrapped with context).
func (c *Classifier) ClassifyFn(params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) {
	if c.vectorizer == nil {
		return nil, errors.Errorf("cannot use text2vec-contextionary-contextual without the respective module")
	}

	// The preparation is done exactly once per classification run, up front.
	prepared, err := c.prepareContextualClassification(
		params.Schema,
		params.VectorRepo,
		params.Params,
		params.Filters,
		params.UnclassifiedItems,
	)
	if err != nil {
		return nil, errors.Wrap(err, "prepare context for text2vec-contextionary-contextual classification")
	}

	// The prepared data is handed to a higher-order function so that it is
	// available to every per-item invocation of the returned function.
	return c.makeClassifyItemContextual(params.Schema, prepared), nil
}
// ParseClassifierSettings converts the untyped params.Settings into a
// *ParamsContextual and stores it back on params.
//
// Nil settings produce a ParamsContextual with only its defaults applied.
// Otherwise the settings must be a map[string]interface{}; the keys
// "minimumUsableWords", "informationGainCutoffPercentile",
// "informationGainMaximumBoost" and "tfidfCutoffPercentile" are read in that
// order through extractNumberFromMap, and SetDefaults is called afterwards.
// The first error encountered is returned and params is left unmodified.
func (c *Classifier) ParseClassifierSettings(params *models.Classification) error {
	parsed := &ParamsContextual{}

	if params.Settings != nil {
		fields, isObject := params.Settings.(map[string]interface{})
		if !isObject {
			return errors.Errorf("settings must be an object got %T", params.Settings)
		}

		// Each entry pairs a settings key with the assignment of its parsed
		// value; the slice order is the order in which keys are validated.
		targets := []struct {
			key    string
			assign func(value *int32)
		}{
			{"minimumUsableWords", func(value *int32) { parsed.MinimumUsableWords = value }},
			{"informationGainCutoffPercentile", func(value *int32) { parsed.InformationGainCutoffPercentile = value }},
			{"informationGainMaximumBoost", func(value *int32) { parsed.InformationGainMaximumBoost = value }},
			{"tfidfCutoffPercentile", func(value *int32) { parsed.TfidfCutoffPercentile = value }},
		}

		for _, target := range targets {
			value, err := c.extractNumberFromMap(fields, target.key)
			if err != nil {
				return err
			}
			target.assign(value)
		}
	}

	parsed.SetDefaults()
	params.Settings = parsed
	return nil
}
// extractNumberFromMap reads the optional integer setting named field from in.
//
// It returns (nil, nil) when the key is absent, leaving it to the caller to
// decide what an unset value means. When the key is present, the value must be
// a json.Number holding an integer that fits into an int32; otherwise an error
// naming the offending field is returned.
func (c *Classifier) extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) {
	unparsed, present := in[field]
	if !present {
		return nil, nil
	}

	parsed, ok := unparsed.(json.Number)
	if !ok {
		return nil, errors.Errorf("settings.%s must be number, got %T",
			field, unparsed)
	}

	asInt64, err := parsed.Int64()
	if err != nil {
		return nil, errors.Wrapf(err, "settings.%s", field)
	}

	// Converting int64 to int32 silently wraps around for out-of-range values
	// (e.g. 4294967297 would become 1). Converting back and comparing detects
	// that, so such input is rejected instead of being accepted as a different
	// number.
	asInt32 := int32(asInt64)
	if int64(asInt32) != asInt64 {
		return nil, errors.Errorf("settings.%s must fit into a 32-bit integer, got %d",
			field, asInt64)
	}

	return &asInt32, nil
}