KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.32 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package classification
import (
"fmt"
"github.com/weaviate/weaviate/entities/errorcompounder"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
schemaUC "github.com/weaviate/weaviate/usecases/schema"
)
// Classification type identifiers, compared against the Type field of the
// classification being validated.
const (
// TypeKNN identifies a kNN classification. The validator also treats an
// empty Type as kNN.
TypeKNN = "knn"
// TypeContextual identifies a text2vec-contextionary contextual
// classification.
TypeContextual = "text2vec-contextionary-contextual"
// TypeZeroShot identifies a zero-shot classification.
TypeZeroShot = "zeroshot"
)
// Validator checks a classification request against the schema and collects
// every problem it finds. Create one with NewValidator and run it with Do.
type Validator struct {
// schema is the schema used to look up the class and its property types.
schema schema.Schema
// errors accumulates the validation errors added by the individual checks.
errors *errorcompounder.SafeErrorCompounder
// subject is the classification being validated.
subject models.Classification
}
// NewValidator builds a Validator for the given classification. The schema is
// fetched once from sg via GetSchemaSkipAuth, and an empty error compounder is
// attached to collect validation failures.
func NewValidator(sg schemaUC.SchemaGetter, subject models.Classification) *Validator {
	v := &Validator{
		errors:  &errorcompounder.SafeErrorCompounder{},
		subject: subject,
	}
	v.schema = sg.GetSchemaSkipAuth()
	return v
}
// Do runs all validation checks against the classification. It returns nil
// when no check reported a problem; otherwise it returns a single error,
// prefixed with "invalid classification: ", built from everything the checks
// collected.
//
// The collected error is wrapped with %w (rather than flattened with %v), so
// the message text is unchanged but callers can still inspect the underlying
// error with errors.Is / errors.As / errors.Unwrap.
func (v *Validator) Do() error {
	v.validate()

	if err := v.errors.ToError(); err != nil {
		return fmt.Errorf("invalid classification: %w", err)
	}
	return nil
}
// validate performs the individual checks and records any failures in
// v.errors. A missing class name or a class that cannot be found in the schema
// stops validation immediately, since every later check needs the class.
func (v *Validator) validate() {
	name := v.subject.Class
	if name == "" {
		v.errors.Add(fmt.Errorf("class must be set"))
		return
	}

	cls := v.schema.FindClassByName(schema.ClassName(name))
	if cls == nil {
		v.errors.Addf("class '%s' not found in schema", name)
		return
	}

	// Type-specific filter checks first, then the property checks.
	v.contextualTypeFeasibility()
	v.knnTypeFeasibility()
	v.basedOnProperties(cls)
	v.classifyProperties(cls)
}
// contextualTypeFeasibility records an error when a contextual classification
// carries a 'trainingSetWhere' filter. It is a no-op for every other type.
func (v *Validator) contextualTypeFeasibility() {
	if !v.typeText2vecContextionaryContextual() {
		return
	}

	filters := v.subject.Filters
	if filters == nil || filters.TrainingSetWhere == nil {
		return
	}

	v.errors.Addf("type is 'text2vec-contextionary-contextual', but 'trainingSetWhere' filter is set, for 'text2vec-contextionary-contextual' there is no training data, instead limit possible target data directly through setting 'targetWhere'")
}
// knnTypeFeasibility records an error when a kNN classification carries a
// 'targetWhere' filter. It is a no-op for every other type.
func (v *Validator) knnTypeFeasibility() {
	if !v.typeKNN() {
		return
	}

	filters := v.subject.Filters
	if filters == nil || filters.TargetWhere == nil {
		return
	}

	v.errors.Addf("type is 'knn', but 'targetWhere' filter is set, for 'knn' you cannot limit target data directly, instead limit training data through setting 'trainingSetWhere'")
}
// basedOnProperties requires exactly one entry in basedOnProperties. An empty
// (or nil) list and a list with more than one entry each record an error; a
// single entry is handed to basedOnProperty for validation against class.
func (v *Validator) basedOnProperties(class *models.Class) {
	props := v.subject.BasedOnProperties

	switch len(props) {
	case 0:
		v.errors.Addf("basedOnProperties must have at least one property")
	case 1:
		v.basedOnProperty(class, props[0])
	default:
		v.errors.Addf("only a single property in basedOnProperties supported at the moment, got %v",
			props)
	}
}
// basedOnProperty checks that propName exists on class and that its data type
// resolves to the primitive type text. The first failing check records an
// error and ends the validation of this property.
func (v *Validator) basedOnProperty(class *models.Class, propName string) {
	prop, found := v.propertyByName(class, propName)
	if !found {
		v.errors.Addf("basedOnProperties: property '%s' does not exist", propName)
		return
	}

	dataType, err := v.schema.FindPropertyDataType(prop.DataType)
	if err != nil {
		v.errors.Addf("basedOnProperties: %v", err)
		return
	}

	// AsPrimitive is only consulted once the type is known to be primitive.
	if isText := dataType.IsPrimitive() && dataType.AsPrimitive() == schema.DataTypeText; !isText {
		v.errors.Addf("basedOnProperties: property '%s' must be of type 'text'", propName)
	}
}
// classifyProperties requires at least one entry in classifyProperties and
// validates each entry against class via classifyProperty.
func (v *Validator) classifyProperties(class *models.Class) {
	props := v.subject.ClassifyProperties
	if len(props) == 0 {
		v.errors.Addf("classifyProperties must have at least one property")
		return
	}

	for i := range props {
		v.classifyProperty(class, props[i])
	}
}
// classifyProperty checks that propName exists on class and is a reference
// property. For contextual classifications it additionally records an error
// when the reference has more than one target class.
func (v *Validator) classifyProperty(class *models.Class, propName string) {
	prop, found := v.propertyByName(class, propName)
	if !found {
		v.errors.Addf("classifyProperties: property '%s' does not exist", propName)
		return
	}

	dataType, err := v.schema.FindPropertyDataType(prop.DataType)
	if err != nil {
		v.errors.Addf("classifyProperties: %v", err)
		return
	}

	if !dataType.IsReference() {
		v.errors.Addf("classifyProperties: property '%s' must be of reference type (cref)", propName)
		return
	}

	if v.typeText2vecContextionaryContextual() && len(dataType.Classes()) > 1 {
		v.errors.Addf("classifyProperties: property '%s'"+
			" has more than one target class, classification of type 'text2vec-contextionary-contextual' requires exactly one target class", propName)
	}
}
// propertyByName returns the first property of class whose Name equals
// propName exactly, together with true. If no property matches it returns
// nil and false.
func (v *Validator) propertyByName(class *models.Class, propName string) (*models.Property, bool) {
	for i := range class.Properties {
		if candidate := class.Properties[i]; candidate.Name == propName {
			return candidate, true
		}
	}

	return nil, false
}
// typeText2vecContextionaryContextual reports whether the classification's
// type is TypeContextual. An unset (empty) type never matches, because
// TypeContextual is a non-empty constant.
func (v *Validator) typeText2vecContextionaryContextual() bool {
	return v.subject.Type == TypeContextual
}
// typeKNN reports whether the classification is of type kNN. An unset (empty)
// type counts as kNN.
func (v *Validator) typeKNN() bool {
	switch v.subject.Type {
	case "", TypeKNN:
		return true
	default:
		return false
	}
}