Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package classification | |
import ( | |
"context" | |
"encoding/json" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
) | |
// vectorizer is the set of vectorization calls the Classifier depends on. It
// is declared here, at the consumer, so that any implementation providing
// these two methods can be injected through New.
type vectorizer interface {
	// MultiVectorForWord returns one vector per entry of words and must keep
	// the input order. If an item cannot be vectorized, its element must be an
	// explicit nil, not skipped, so that indices stay aligned with words.
	MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error)

	// VectorOnlyForCorpi returns a single vector for the given corpi.
	// NOTE(review): the meaning of the overrides map is not visible in this
	// file; confirm against the implementing module.
	VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error)
}
type Classifier struct { | |
vectorizer vectorizer | |
} | |
func New(vectorizer vectorizer) modulecapabilities.Classifier { | |
return &Classifier{vectorizer: vectorizer} | |
} | |
func (c *Classifier) Name() string { | |
return "text2vec-contextionary-contextual" | |
} | |
func (c *Classifier) ClassifyFn(params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) { | |
if c.vectorizer == nil { | |
return nil, errors.Errorf("cannot use text2vec-contextionary-contextual " + | |
"without the respective module") | |
} | |
// 1. do preparation here once | |
preparedContext, err := c.prepareContextualClassification(params.Schema, params.VectorRepo, | |
params.Params, params.Filters, params.UnclassifiedItems) | |
if err != nil { | |
return nil, errors.Wrap(err, "prepare context for text2vec-contextionary-contextual classification") | |
} | |
// 2. use higher order function to inject preparation data so it is then present for each single run | |
return c.makeClassifyItemContextual(params.Schema, preparedContext), nil | |
} | |
func (c *Classifier) ParseClassifierSettings(params *models.Classification) error { | |
raw := params.Settings | |
settings := &ParamsContextual{} | |
if raw == nil { | |
settings.SetDefaults() | |
params.Settings = settings | |
return nil | |
} | |
asMap, ok := raw.(map[string]interface{}) | |
if !ok { | |
return errors.Errorf("settings must be an object got %T", raw) | |
} | |
v, err := c.extractNumberFromMap(asMap, "minimumUsableWords") | |
if err != nil { | |
return err | |
} | |
settings.MinimumUsableWords = v | |
v, err = c.extractNumberFromMap(asMap, "informationGainCutoffPercentile") | |
if err != nil { | |
return err | |
} | |
settings.InformationGainCutoffPercentile = v | |
v, err = c.extractNumberFromMap(asMap, "informationGainMaximumBoost") | |
if err != nil { | |
return err | |
} | |
settings.InformationGainMaximumBoost = v | |
v, err = c.extractNumberFromMap(asMap, "tfidfCutoffPercentile") | |
if err != nil { | |
return err | |
} | |
settings.TfidfCutoffPercentile = v | |
settings.SetDefaults() | |
params.Settings = settings | |
return nil | |
} | |
func (c *Classifier) extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) { | |
unparsed, present := in[field] | |
if present { | |
parsed, ok := unparsed.(json.Number) | |
if !ok { | |
return nil, errors.Errorf("settings.%s must be number, got %T", | |
field, unparsed) | |
} | |
asInt64, err := parsed.Int64() | |
if err != nil { | |
return nil, errors.Wrapf(err, "settings.%s", field) | |
} | |
asInt32 := int32(asInt64) | |
return &asInt32, nil | |
} | |
return nil, nil | |
} | |