File size: 3,443 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package classification

import (
	"context"
	"encoding/json"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
)

// vectorizer is the subset of vectorization functionality the contextual
// classifier depends on. It is declared here, at the consumer, so that any
// implementation providing these two methods can be injected through New.
type vectorizer interface {
	// MultiVectorForWord must keep order, if an item cannot be vectorized, the
	// element should be explicit nil, not skipped
	MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error)
	// VectorOnlyForCorpi takes a list of corpi together with a map of string
	// overrides and returns a single vector.
	// NOTE(review): the meaning of the overrides keys/values is not visible in
	// this file — confirm against the implementing module.
	VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error)
}

// Classifier is the "text2vec-contextionary-contextual" classifier. It holds
// the vectorizer it was constructed with; ClassifyFn refuses to run when that
// vectorizer is nil.
type Classifier struct {
	// vectorizer is the dependency injected through New; it may be nil.
	vectorizer vectorizer
}

// New builds a Classifier around the given vectorizer and returns it as a
// modulecapabilities.Classifier. The vectorizer is stored as-is and is not
// validated here.
func New(vectorizer vectorizer) modulecapabilities.Classifier {
	classifier := new(Classifier)
	classifier.vectorizer = vectorizer
	return classifier
}

// Name returns the fixed identifier of this classifier type.
func (c *Classifier) Name() string {
	const classifierName = "text2vec-contextionary-contextual"
	return classifierName
}

// ClassifyFn produces the per-item classification function for one
// classification run.
//
// The shared preparation step is executed exactly once, up front, and its
// result is handed to makeClassifyItemContextual so that the returned function
// has it available on every invocation.
//
// An error is returned when the classifier was built without a vectorizer, or
// when the preparation step fails (the underlying error is wrapped).
func (c *Classifier) ClassifyFn(params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) {
	// Without a vectorizer there is nothing to classify with.
	if c.vectorizer == nil {
		return nil, errors.Errorf("cannot use text2vec-contextionary-contextual without the respective module")
	}

	// One-off preparation shared by all items of this run.
	prepared, prepErr := c.prepareContextualClassification(
		params.Schema,
		params.VectorRepo,
		params.Params,
		params.Filters,
		params.UnclassifiedItems,
	)
	if prepErr != nil {
		return nil, errors.Wrap(prepErr, "prepare context for text2vec-contextionary-contextual classification")
	}

	// Bind the prepared data into the function that handles a single item.
	classifyItem := c.makeClassifyItemContextual(params.Schema, prepared)
	return classifyItem, nil
}

// ParseClassifierSettings converts the untyped params.Settings into a
// *ParamsContextual and stores the result back on params.
//
// When no settings were supplied, a fresh ParamsContextual is created and
// SetDefaults is called on it. Otherwise the settings must be a
// map[string]interface{}; each known numeric key is read through
// extractNumberFromMap, copied onto the typed struct, and SetDefaults is
// called afterwards.
//
// On any error params.Settings is left untouched and the error is returned.
func (c *Classifier) ParseClassifierSettings(params *models.Classification) error {
	parsed := &ParamsContextual{}

	if raw := params.Settings; raw != nil {
		asMap, ok := raw.(map[string]interface{})
		if !ok {
			return errors.Errorf("settings must be an object got %T", raw)
		}

		// Keys are processed in this fixed order so that the first invalid
		// one determines the returned error.
		fields := []struct {
			key    string
			assign func(*int32)
		}{
			{"minimumUsableWords", func(n *int32) { parsed.MinimumUsableWords = n }},
			{"informationGainCutoffPercentile", func(n *int32) { parsed.InformationGainCutoffPercentile = n }},
			{"informationGainMaximumBoost", func(n *int32) { parsed.InformationGainMaximumBoost = n }},
			{"tfidfCutoffPercentile", func(n *int32) { parsed.TfidfCutoffPercentile = n }},
		}

		for _, f := range fields {
			n, err := c.extractNumberFromMap(asMap, f.key)
			if err != nil {
				return err
			}
			f.assign(n)
		}
	}

	parsed.SetDefaults()
	params.Settings = parsed

	return nil
}

// extractNumberFromMap reads an optional integer setting from a decoded
// settings object.
//
// It returns (nil, nil) when field is absent from in, which lets callers tell
// "not set" apart from an explicit value. When the field is present it must
// hold a json.Number that parses as an integer and fits into an int32;
// otherwise an error naming the offending settings field is returned.
func (c *Classifier) extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) {
	unparsed, present := in[field]
	if !present {
		return nil, nil
	}

	parsed, ok := unparsed.(json.Number)
	if !ok {
		return nil, errors.Errorf("settings.%s must be number, got %T",
			field, unparsed)
	}

	asInt64, err := parsed.Int64()
	if err != nil {
		return nil, errors.Wrapf(err, "settings.%s", field)
	}

	// A bare int32(asInt64) conversion silently wraps around for values
	// outside the int32 range, which would turn an invalid setting into a
	// different, seemingly valid one. Reject anything that does not survive
	// the round trip.
	asInt32 := int32(asInt64)
	if int64(asInt32) != asInt64 {
		return nil, errors.Errorf("settings.%s must fit into a 32-bit integer, got %d",
			field, asInt64)
	}

	return &asInt32, nil
}