Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package vectorizer | |
import ( | |
"encoding/json" | |
"fmt" | |
"strconv" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/moduletools" | |
) | |
type classSettings struct { | |
cfg moduletools.ClassConfig | |
} | |
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { | |
return &classSettings{cfg: cfg} | |
} | |
func (ic *classSettings) ImageField(property string) bool { | |
return ic.field("imageFields", property) | |
} | |
func (ic *classSettings) ImageFieldsWeights() ([]float32, error) { | |
return ic.getFieldsWeights("image") | |
} | |
func (ic *classSettings) TextField(property string) bool { | |
return ic.field("textFields", property) | |
} | |
func (ic *classSettings) TextFieldsWeights() ([]float32, error) { | |
return ic.getFieldsWeights("text") | |
} | |
func (ic *classSettings) field(name, property string) bool { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return false | |
} | |
fields, ok := ic.cfg.Class()[name] | |
if !ok { | |
return false | |
} | |
fieldsArray, ok := fields.([]interface{}) | |
if !ok { | |
return false | |
} | |
fieldNames := make([]string, len(fieldsArray)) | |
for i, value := range fieldsArray { | |
fieldNames[i] = value.(string) | |
} | |
for i := range fieldNames { | |
if fieldNames[i] == property { | |
return true | |
} | |
} | |
return false | |
} | |
func (ic *classSettings) Validate() error { | |
if ic.cfg == nil { | |
// we would receive a nil-config on cross-class requests, such as Explore{} | |
return errors.New("empty config") | |
} | |
imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"] | |
textFields, textFieldsOk := ic.cfg.Class()["textFields"] | |
if !imageFieldsOk && !textFieldsOk { | |
return errors.New("textFields or imageFields setting needs to be present") | |
} | |
if imageFieldsOk { | |
imageFieldsCount, err := ic.validateFields("image", imageFields) | |
if err != nil { | |
return err | |
} | |
err = ic.validateWeights("image", imageFieldsCount) | |
if err != nil { | |
return err | |
} | |
} | |
if textFieldsOk { | |
textFieldsCount, err := ic.validateFields("text", textFields) | |
if err != nil { | |
return err | |
} | |
err = ic.validateWeights("text", textFieldsCount) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) { | |
fieldsArray, ok := fields.([]interface{}) | |
if !ok { | |
return 0, errors.Errorf("%sFields must be an array", name) | |
} | |
if len(fieldsArray) == 0 { | |
return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name) | |
} | |
for _, value := range fieldsArray { | |
v, ok := value.(string) | |
if !ok { | |
return 0, errors.Errorf("%sField must be a string", name) | |
} | |
if len(v) == 0 { | |
return 0, errors.Errorf("%sField values cannot be empty", name) | |
} | |
} | |
return len(fieldsArray), nil | |
} | |
func (ic *classSettings) validateWeights(name string, count int) error { | |
weights, ok := ic.getWeights(name) | |
if ok { | |
if len(weights) != count { | |
return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name) | |
} | |
_, err := ic.getWeightsArray(weights) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (ic *classSettings) getWeights(name string) ([]interface{}, bool) { | |
weights, ok := ic.cfg.Class()["weights"] | |
if ok { | |
weightsObject, ok := weights.(map[string]interface{}) | |
if ok { | |
fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)] | |
if ok { | |
fieldWeightsArray, ok := fieldWeights.([]interface{}) | |
if ok { | |
return fieldWeightsArray, ok | |
} | |
} | |
} | |
} | |
return nil, false | |
} | |
func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) { | |
weightsArray := make([]float32, len(weights)) | |
for i := range weights { | |
weight, err := ic.getNumber(weights[i]) | |
if err != nil { | |
return nil, err | |
} | |
weightsArray[i] = weight | |
} | |
return weightsArray, nil | |
} | |
func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) { | |
weights, ok := ic.getWeights(name) | |
if ok { | |
return ic.getWeightsArray(weights) | |
} | |
return nil, nil | |
} | |
func (ic *classSettings) getNumber(in interface{}) (float32, error) { | |
switch i := in.(type) { | |
case float64: | |
return float32(i), nil | |
case float32: | |
return i, nil | |
case int: | |
return float32(i), nil | |
case string: | |
num, err := strconv.ParseFloat(i, 64) | |
if err != nil { | |
return 0, err | |
} | |
return float32(num), err | |
case json.Number: | |
num, err := i.Float64() | |
if err != nil { | |
return 0, err | |
} | |
return float32(num), err | |
default: | |
return 0.0, errors.Errorf("Unrecognized weight entry type: %T", i) | |
} | |
} | |