KevinStephenson
Adding in weaviate code
b110593
raw
history blame
22.2 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package classification
import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
"time"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
testhelper "github.com/weaviate/weaviate/test/helper"
)
// newNullLogger returns the logger created by logrus' test.NewNullLogger.
// The second value returned by that constructor is not needed by these tests
// and is ignored.
func newNullLogger() *logrus.Logger {
	logger, _ := test.NewNullLogger()
	return logger
}
// Test_Classifier_KNN drives the classifier through Schedule and Get using
// fake schema, classification and vector repos. It covers four scenarios:
// invalid input, a successful run, a run where the vector repo is configured
// to return an error on aggregation, and a run with no source objects.
func Test_Classifier_KNN(t *testing.T) {
	t.Run("with invalid data", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		_, err := New(sg, nil, nil, &fakeAuthorizer{}, newNullLogger(), nil).
			Schedule(context.Background(), nil, models.Classification{})
		assert.NotNil(t, err, "should error with invalid user input")
	})

	// id is declared outside the subtests so that follow-up requests, such as
	// checking the status, can reuse the id assigned when scheduling.
	var id strfmt.UUID

	t.Run("with valid data", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Settings: map[string]interface{}{
				"k": json.Number("1"),
			},
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		t.Run("retrieving the same classification by id", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, id, class.ID)
			assert.Equal(t, models.ClassificationStatusRunning, class.Status)
		})

		// Poll until the classification leaves the running state instead of
		// sleeping for a fixed duration; this is the same helper the error
		// scenarios below rely on.
		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now completed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
		})

		t.Run("the classifier updated the actions with the classified references", func(t *testing.T) {
			vectorRepo.Lock()
			require.Len(t, vectorRepo.db, 6)
			vectorRepo.Unlock()

			t.Run("food", func(t *testing.T) {
				idArticleFoodOne := "06a1e824-889c-4649-97f9-1ed3fa401d8e"
				idArticleFoodTwo := "6402e649-b1e0-40ea-b192-a64eab0d5e56"

				checkRef(t, vectorRepo, idArticleFoodOne, "exactCategory", idCategoryFoodAndDrink)
				checkRef(t, vectorRepo, idArticleFoodTwo, "mainCategory", idMainCategoryFoodAndDrink)
			})

			t.Run("politics", func(t *testing.T) {
				idArticlePoliticsOne := "75ba35af-6a08-40ae-b442-3bec69b355f9"
				idArticlePoliticsTwo := "f850439a-d3cd-4f17-8fbf-5a64405645cd"

				checkRef(t, vectorRepo, idArticlePoliticsOne, "exactCategory", idCategoryPolitics)
				checkRef(t, vectorRepo, idArticlePoliticsTwo, "mainCategory", idMainCategoryPoliticsAndSociety)
			})

			t.Run("society", func(t *testing.T) {
				idArticleSocietyOne := "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"
				idArticleSocietyTwo := "069410c3-4b9e-4f68-8034-32a066cb7997"

				checkRef(t, vectorRepo, idArticleSocietyOne, "exactCategory", idCategorySociety)
				checkRef(t, vectorRepo, idArticleSocietyTwo, "mainCategory", idMainCategoryPoliticsAndSociety)
			})
		})
	})

	t.Run("when errors occur during classification", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
		vectorRepo.errorOnAggregate = errors.New("something went wrong")
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Settings: map[string]interface{}{
				"k": json.Number("1"),
			},
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now failed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusFailed, class.Status)

			// Each fragment is checked with Contains, so the order in which
			// the per-object errors appear in class.Error does not matter.
			expectedErrStrings := []string{
				"classification failed: ",
				"classify Article/75ba35af-6a08-40ae-b442-3bec69b355f9: something went wrong",
				"classify Article/f850439a-d3cd-4f17-8fbf-5a64405645cd: something went wrong",
				"classify Article/a2bbcbdc-76e1-477d-9e72-a6d2cfb50109: something went wrong",
				"classify Article/069410c3-4b9e-4f68-8034-32a066cb7997: something went wrong",
				"classify Article/06a1e824-889c-4649-97f9-1ed3fa401d8e: something went wrong",
				"classify Article/6402e649-b1e0-40ea-b192-a64eab0d5e56: something went wrong",
			}
			for _, msg := range expectedErrStrings {
				assert.Contains(t, class.Error, msg)
			}
		})
	})

	t.Run("when there is nothing to be classified", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		// nil source data: the vector repo holds nothing to classify.
		vectorRepo := newFakeVectorRepoKNN(nil, testDataAlreadyClassified())
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Settings: map[string]interface{}{
				"k": json.Number("1"),
			},
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now failed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusFailed, class.Status)

			expectedErr := "classification failed: " +
				"no classes to be classified - did you run a previous classification already?"
			assert.Equal(t, expectedErr, class.Error)
		})
	})
}
// Test_Classifier_Custom_Classifier drives the classifier with classification
// types that are resolved through a modules provider. It covers a type name
// the provider does not know, a successful run with a known custom type, and
// two failure scenarios that use the kNN fake vector repo without a modules
// provider.
func Test_Classifier_Custom_Classifier(t *testing.T) {
	// id is declared outside the subtests so that follow-up requests, such as
	// checking the status, can reuse the id assigned when scheduling.
	var id strfmt.UUID

	t.Run("with unreconginzed custom module classifier name", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoContextual(testDataToBeClassified(), testDataPossibleTargets())
		modulesProvider := NewFakeModulesProvider()
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), modulesProvider)

		notRecognizedContextual := "text2vec-contextionary-custom-not-recognized"
		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Type:               notRecognizedContextual,
		}

		t.Run("scheduling an unrecognized classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		t.Run("retrieving the same classification by id", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, id, class.ID)
		})

		// Poll until the classification leaves the running state instead of
		// sleeping for a fixed duration.
		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is failed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusFailed, class.Status)
			assert.Equal(t, notRecognizedContextual, class.Type)
			assert.Contains(t, class.Error, "classifier "+notRecognizedContextual+" not found")
		})
	})

	t.Run("with valid data", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoContextual(testDataToBeClassified(), testDataPossibleTargets())
		modulesProvider := NewFakeModulesProvider()
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), modulesProvider)

		contextual := "text2vec-contextionary-custom-contextual"
		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Type:               contextual,
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		t.Run("retrieving the same classification by id", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, id, class.ID)
		})

		// Poll until the classification leaves the running state instead of
		// sleeping for a fixed duration.
		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now completed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
		})

		t.Run("the classifier updated the actions with the classified references", func(t *testing.T) {
			vectorRepo.Lock()
			require.Len(t, vectorRepo.db, 6)
			vectorRepo.Unlock()

			t.Run("food", func(t *testing.T) {
				idArticleFoodOne := "06a1e824-889c-4649-97f9-1ed3fa401d8e"
				idArticleFoodTwo := "6402e649-b1e0-40ea-b192-a64eab0d5e56"

				checkRef(t, vectorRepo, idArticleFoodOne, "exactCategory", idCategoryFoodAndDrink)
				checkRef(t, vectorRepo, idArticleFoodTwo, "mainCategory", idMainCategoryFoodAndDrink)
			})

			t.Run("politics", func(t *testing.T) {
				idArticlePoliticsOne := "75ba35af-6a08-40ae-b442-3bec69b355f9"
				idArticlePoliticsTwo := "f850439a-d3cd-4f17-8fbf-5a64405645cd"

				checkRef(t, vectorRepo, idArticlePoliticsOne, "exactCategory", idCategoryPolitics)
				checkRef(t, vectorRepo, idArticlePoliticsTwo, "mainCategory", idMainCategoryPoliticsAndSociety)
			})

			t.Run("society", func(t *testing.T) {
				idArticleSocietyOne := "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"
				idArticleSocietyTwo := "069410c3-4b9e-4f68-8034-32a066cb7997"

				checkRef(t, vectorRepo, idArticleSocietyOne, "exactCategory", idCategorySociety)
				checkRef(t, vectorRepo, idArticleSocietyTwo, "mainCategory", idMainCategoryPoliticsAndSociety)
			})
		})
	})

	t.Run("when errors occur during classification", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
		vectorRepo.errorOnAggregate = errors.New("something went wrong")
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Settings: map[string]interface{}{
				"k": json.Number("1"),
			},
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now failed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusFailed, class.Status)

			// Each fragment is checked with Contains, so the order in which
			// the per-object errors appear in class.Error does not matter.
			expectedErrStrings := []string{
				"classification failed: ",
				"classify Article/75ba35af-6a08-40ae-b442-3bec69b355f9: something went wrong",
				"classify Article/f850439a-d3cd-4f17-8fbf-5a64405645cd: something went wrong",
				"classify Article/a2bbcbdc-76e1-477d-9e72-a6d2cfb50109: something went wrong",
				"classify Article/069410c3-4b9e-4f68-8034-32a066cb7997: something went wrong",
				"classify Article/06a1e824-889c-4649-97f9-1ed3fa401d8e: something went wrong",
				"classify Article/6402e649-b1e0-40ea-b192-a64eab0d5e56: something went wrong",
			}
			for _, msg := range expectedErrStrings {
				assert.Contains(t, class.Error, msg)
			}
		})
	})

	t.Run("when there is nothing to be classified", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		// nil source data: the vector repo holds nothing to classify.
		vectorRepo := newFakeVectorRepoKNN(nil, testDataAlreadyClassified())
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		params := models.Classification{
			Class:              "Article",
			BasedOnProperties:  []string{"description"},
			ClassifyProperties: []string{"exactCategory", "mainCategory"},
			Settings: map[string]interface{}{
				"k": json.Number("1"),
			},
		}

		t.Run("scheduling a classification", func(t *testing.T) {
			class, err := classifier.Schedule(context.Background(), nil, params)
			require.Nil(t, err, "should not error")
			require.NotNil(t, class)

			assert.Len(t, class.ID, 36, "an id was assigned")
			id = class.ID
		})

		waitForStatusToNoLongerBeRunning(t, classifier, id)

		t.Run("status is now failed", func(t *testing.T) {
			class, err := classifier.Get(context.Background(), nil, id)
			require.Nil(t, err)
			require.NotNil(t, class)
			assert.Equal(t, models.ClassificationStatusFailed, class.Status)

			expectedErr := "classification failed: " +
				"no classes to be classified - did you run a previous classification already?"
			assert.Equal(t, expectedErr, class.Error)
		})
	})
}
// Test_Classifier_WhereFilterValidation schedules classifications with various
// combinations of where filters and checks that Schedule returns no error.
// The first group sets exactly one filter per classification type; the second
// group uses filters carrying the deprecated ValueString field.
func Test_Classifier_WhereFilterValidation(t *testing.T) {
	t.Run("when invalid whereFilters are received", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		t.Run("with only one of the where filters being set", func(t *testing.T) {
			whereFilter := &models.WhereFilter{
				Path:      []string{"id"},
				Operator:  "Like",
				ValueText: ptString("*"),
			}

			testData := []struct {
				name                  string
				classificationType    string
				classificationFilters *models.ClassificationFilters
			}{
				{
					name:               "Contextual only source where filter set",
					classificationType: TypeContextual,
					classificationFilters: &models.ClassificationFilters{
						SourceWhere: whereFilter,
					},
				},
				{
					name:               "Contextual only target where filter set",
					classificationType: TypeContextual,
					classificationFilters: &models.ClassificationFilters{
						TargetWhere: whereFilter,
					},
				},
				{
					name:               "ZeroShot only source where filter set",
					classificationType: TypeZeroShot,
					classificationFilters: &models.ClassificationFilters{
						SourceWhere: whereFilter,
					},
				},
				{
					name:               "ZeroShot only target where filter set",
					classificationType: TypeZeroShot,
					classificationFilters: &models.ClassificationFilters{
						TargetWhere: whereFilter,
					},
				},
				{
					name:               "KNN only source where filter set",
					classificationType: TypeKNN,
					classificationFilters: &models.ClassificationFilters{
						SourceWhere: whereFilter,
					},
				},
				{
					name:               "KNN only training set where filter set",
					classificationType: TypeKNN,
					classificationFilters: &models.ClassificationFilters{
						TrainingSetWhere: whereFilter,
					},
				},
			}

			for _, td := range testData {
				t.Run(td.name, func(t *testing.T) {
					params := models.Classification{
						Class:              "Article",
						BasedOnProperties:  []string{"description"},
						ClassifyProperties: []string{"exactCategory", "mainCategory"},
						Settings: map[string]interface{}{
							"k": json.Number("1"),
						},
						Type:    td.classificationType,
						Filters: td.classificationFilters,
					}

					class, err := classifier.Schedule(context.Background(), nil, params)
					// require (not assert): class.ID is dereferenced below, so
					// the subtest must stop here rather than panic on a nil
					// class.
					require.Nil(t, err)
					require.NotNil(t, class)

					assert.Len(t, class.ID, 36, "an id was assigned")

					// Only waits for the run to leave the running state; the
					// final status is not asserted here.
					waitForStatusToNoLongerBeRunning(t, classifier, class.ID)
				})
			}
		})
	})

	t.Run("[deprecated string] when valueString whereFilters are received", func(t *testing.T) {
		sg := &fakeSchemaGetter{testSchema()}
		repo := newFakeClassificationRepo()
		authorizer := &fakeAuthorizer{}
		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
		classifier := New(sg, repo, vectorRepo, authorizer, newNullLogger(), nil)

		validFilter := &models.WhereFilter{
			Path:      []string{"description"},
			Operator:  "Equal",
			ValueText: ptString("valueText is valid"),
		}
		deprecatedFilter := &models.WhereFilter{
			Path:        []string{"description"},
			Operator:    "Equal",
			ValueString: ptString("valueString is accepted"),
		}

		t.Run("with deprecated sourceFilter", func(t *testing.T) {
			params := models.Classification{
				Class:              "Article",
				BasedOnProperties:  []string{"description"},
				ClassifyProperties: []string{"exactCategory", "mainCategory"},
				Settings: map[string]interface{}{
					"k": json.Number("1"),
				},
				Filters: &models.ClassificationFilters{
					SourceWhere: deprecatedFilter,
				},
				Type: TypeContextual,
			}

			_, err := classifier.Schedule(context.Background(), nil, params)
			assert.Nil(t, err)
		})

		t.Run("with deprecated targetFilter", func(t *testing.T) {
			params := models.Classification{
				Class:              "Article",
				BasedOnProperties:  []string{"description"},
				ClassifyProperties: []string{"exactCategory", "mainCategory"},
				Settings: map[string]interface{}{
					"k": json.Number("1"),
				},
				Filters: &models.ClassificationFilters{
					SourceWhere: validFilter,
					TargetWhere: deprecatedFilter,
				},
				Type: TypeContextual,
			}

			_, err := classifier.Schedule(context.Background(), nil, params)
			assert.Nil(t, err)
		})

		t.Run("with deprecated trainingFilter", func(t *testing.T) {
			params := models.Classification{
				Class:              "Article",
				BasedOnProperties:  []string{"description"},
				ClassifyProperties: []string{"exactCategory", "mainCategory"},
				Settings: map[string]interface{}{
					"k": json.Number("1"),
				},
				Filters: &models.ClassificationFilters{
					SourceWhere:      validFilter,
					TrainingSetWhere: deprecatedFilter,
				},
				Type: TypeKNN,
			}

			_, err := classifier.Schedule(context.Background(), nil, params)
			assert.Nil(t, err)
		})
	})
}
// genericFakeRepo is the minimal lookup surface checkRef requires from a fake
// repo: fetching a stored object by its UUID, with a flag reporting whether
// the object was present.
type genericFakeRepo interface {
	get(id strfmt.UUID) (obj *models.Object, found bool)
}
// checkRef asserts that the object stored under source in repo has a property
// propName holding exactly one reference, and that the reference's beacon is
// "weaviate://localhost/<target>".
//
// Every precondition is checked with require, so the calling test stops at the
// first unmet one instead of panicking on a later dereference.
func checkRef(t *testing.T, repo genericFakeRepo, source, propName, target string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	object, ok := repo.get(strfmt.UUID(source))
	require.True(t, ok, "object must be present")
	require.NotNil(t, object, "object must not be nil")

	schema, ok := object.Properties.(map[string]interface{})
	require.True(t, ok, "schema must be map")

	prop, ok := schema[propName]
	require.True(t, ok, "ref prop must be present")

	refs, ok := prop.(models.MultipleRef)
	require.True(t, ok, "ref prop must be models.MultipleRef")
	require.Len(t, refs, 1, "refs must have len 1")
	require.NotNil(t, refs[0], "ref must not be nil")

	assert.Equal(t, fmt.Sprintf("weaviate://localhost/%s", target), refs[0].Beacon.String(), "beacon must match")
}
// waitForStatusToNoLongerBeRunning blocks until the classification identified
// by id reports a status other than models.ClassificationStatusRunning.
//
// The check is handed to testhelper.AssertEventuallyEqualWithFrequencyAndTimeout
// together with a 100ms frequency and a 20s timeout. Each evaluation fetches
// the classification via classifier.Get and requires that the call succeeds
// with a non-nil result.
func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *Classifier, id strfmt.UUID) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	leftRunningState := func() interface{} {
		class, err := classifier.Get(context.Background(), nil, id)
		require.Nil(t, err)
		require.NotNil(t, class)

		return class.Status != models.ClassificationStatusRunning
	}

	testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, leftRunningState,
		100*time.Millisecond, 20*time.Second, "wait until status in no longer running")
}