KevinStephenson
Adding in weaviate code
b110593
raw
history blame
12.5 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package classification
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema/crossref"
testhelper "github.com/weaviate/weaviate/test/helper"
usecasesclassfication "github.com/weaviate/weaviate/usecases/classification"
)
func TestContextualClassifier_ParseSettings(t *testing.T) {
t.Run("should parse with default values with empty settings are passed", func(t *testing.T) {
// given
classifier := New(&fakeVectorizer{})
params := &models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Type: "text2vec-contextionary-contextual",
}
// when
err := classifier.ParseClassifierSettings(params)
// then
assert.Nil(t, err)
settings := params.Settings
assert.NotNil(t, settings)
paramsContextual, ok := settings.(*ParamsContextual)
assert.NotNil(t, paramsContextual)
assert.True(t, ok)
assert.Equal(t, int32(3), *paramsContextual.MinimumUsableWords)
assert.Equal(t, int32(50), *paramsContextual.InformationGainCutoffPercentile)
assert.Equal(t, int32(3), *paramsContextual.InformationGainMaximumBoost)
assert.Equal(t, int32(80), *paramsContextual.TfidfCutoffPercentile)
})
t.Run("should parse classifier settings", func(t *testing.T) {
// given
classifier := New(&fakeVectorizer{})
params := &models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Type: "text2vec-contextionary-contextual",
Settings: map[string]interface{}{
"minimumUsableWords": json.Number("1"),
"informationGainCutoffPercentile": json.Number("2"),
"informationGainMaximumBoost": json.Number("3"),
"tfidfCutoffPercentile": json.Number("4"),
},
}
// when
err := classifier.ParseClassifierSettings(params)
// then
assert.Nil(t, err)
assert.NotNil(t, params.Settings)
settings, ok := params.Settings.(*ParamsContextual)
assert.NotNil(t, settings)
assert.True(t, ok)
assert.Equal(t, int32(1), *settings.MinimumUsableWords)
assert.Equal(t, int32(2), *settings.InformationGainCutoffPercentile)
assert.Equal(t, int32(3), *settings.InformationGainMaximumBoost)
assert.Equal(t, int32(4), *settings.TfidfCutoffPercentile)
})
}
func TestContextualClassifier_Classify(t *testing.T) {
var id strfmt.UUID
// so we can reuse it for follow up requests, such as checking the status
t.Run("with valid data", func(t *testing.T) {
sg := &fakeSchemaGetter{testSchema()}
repo := newFakeClassificationRepo()
authorizer := &fakeAuthorizer{}
vectorRepo := newFakeVectorRepoContextual(testDataToBeClassified(), testDataPossibleTargets())
logger, _ := test.NewNullLogger()
vectorizer := &fakeVectorizer{words: testDataVectors()}
modulesProvider := NewFakeModulesProvider(vectorizer)
classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, modulesProvider)
contextual := "text2vec-contextionary-contextual"
params := models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Type: contextual,
}
t.Run("scheduling a classification", func(t *testing.T) {
class, err := classifier.Schedule(context.Background(), nil, params)
require.Nil(t, err, "should not error")
require.NotNil(t, class)
assert.Len(t, class.ID, 36, "an id was assigned")
id = class.ID
})
t.Run("retrieving the same classification by id", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, id, class.ID)
})
// TODO: improve by polling instead
time.Sleep(500 * time.Millisecond)
t.Run("status is now completed", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
})
t.Run("the classifier updated the actions with the classified references", func(t *testing.T) {
vectorRepo.Lock()
require.Len(t, vectorRepo.db, 6)
vectorRepo.Unlock()
t.Run("food", func(t *testing.T) {
idArticleFoodOne := "06a1e824-889c-4649-97f9-1ed3fa401d8e"
idArticleFoodTwo := "6402e649-b1e0-40ea-b192-a64eab0d5e56"
checkRef(t, vectorRepo, idArticleFoodOne, "ExactCategory", "exactCategory", idCategoryFoodAndDrink)
checkRef(t, vectorRepo, idArticleFoodTwo, "MainCategory", "mainCategory", idMainCategoryFoodAndDrink)
})
t.Run("politics", func(t *testing.T) {
idArticlePoliticsOne := "75ba35af-6a08-40ae-b442-3bec69b355f9"
idArticlePoliticsTwo := "f850439a-d3cd-4f17-8fbf-5a64405645cd"
checkRef(t, vectorRepo, idArticlePoliticsOne, "ExactCategory", "exactCategory", idCategoryPolitics)
checkRef(t, vectorRepo, idArticlePoliticsTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety)
})
t.Run("society", func(t *testing.T) {
idArticleSocietyOne := "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"
idArticleSocietyTwo := "069410c3-4b9e-4f68-8034-32a066cb7997"
checkRef(t, vectorRepo, idArticleSocietyOne, "ExactCategory", "exactCategory", idCategorySociety)
checkRef(t, vectorRepo, idArticleSocietyTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety)
})
})
})
t.Run("when errors occur during classification", func(t *testing.T) {
sg := &fakeSchemaGetter{testSchema()}
repo := newFakeClassificationRepo()
authorizer := &fakeAuthorizer{}
vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
vectorRepo.errorOnAggregate = errors.New("something went wrong")
logger, _ := test.NewNullLogger()
classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil)
params := models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Settings: map[string]interface{}{
"k": json.Number("1"),
},
}
t.Run("scheduling a classification", func(t *testing.T) {
class, err := classifier.Schedule(context.Background(), nil, params)
require.Nil(t, err, "should not error")
require.NotNil(t, class)
assert.Len(t, class.ID, 36, "an id was assigned")
id = class.ID
})
waitForStatusToNoLongerBeRunning(t, classifier, id)
t.Run("status is now failed", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, models.ClassificationStatusFailed, class.Status)
expectedErrStrings := []string{
"classification failed: ",
"classify Article/75ba35af-6a08-40ae-b442-3bec69b355f9: something went wrong",
"classify Article/f850439a-d3cd-4f17-8fbf-5a64405645cd: something went wrong",
"classify Article/a2bbcbdc-76e1-477d-9e72-a6d2cfb50109: something went wrong",
"classify Article/069410c3-4b9e-4f68-8034-32a066cb7997: something went wrong",
"classify Article/06a1e824-889c-4649-97f9-1ed3fa401d8e: something went wrong",
"classify Article/6402e649-b1e0-40ea-b192-a64eab0d5e56: something went wrong",
}
for _, msg := range expectedErrStrings {
assert.Contains(t, class.Error, msg)
}
})
})
t.Run("when there is nothing to be classified", func(t *testing.T) {
sg := &fakeSchemaGetter{testSchema()}
repo := newFakeClassificationRepo()
authorizer := &fakeAuthorizer{}
vectorRepo := newFakeVectorRepoKNN(nil, testDataAlreadyClassified())
logger, _ := test.NewNullLogger()
classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil)
params := models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Settings: map[string]interface{}{
"k": json.Number("1"),
},
}
t.Run("scheduling a classification", func(t *testing.T) {
class, err := classifier.Schedule(context.Background(), nil, params)
require.Nil(t, err, "should not error")
require.NotNil(t, class)
assert.Len(t, class.ID, 36, "an id was assigned")
id = class.ID
})
waitForStatusToNoLongerBeRunning(t, classifier, id)
t.Run("status is now failed", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, models.ClassificationStatusFailed, class.Status)
expectedErr := "classification failed: " +
"no classes to be classified - did you run a previous classification already?"
assert.Equal(t, expectedErr, class.Error)
})
})
}
func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *usecasesclassfication.Classifier, id strfmt.UUID) {
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
return class.Status != models.ClassificationStatusRunning
}, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running")
}
type genericFakeRepo interface {
get(strfmt.UUID) (*models.Object, bool)
}
func checkRef(t *testing.T, repo genericFakeRepo, source, targetClass, propName, target string) {
object, ok := repo.get(strfmt.UUID(source))
require.True(t, ok, "object must be present")
schema, ok := object.Properties.(map[string]interface{})
require.True(t, ok, "schema must be map")
prop, ok := schema[propName]
require.True(t, ok, "ref prop must be present")
refs, ok := prop.(models.MultipleRef)
require.True(t, ok, "ref prop must be models.MultipleRef")
require.Len(t, refs, 1, "refs must have len 1")
assert.Equal(t, crossref.NewLocalhost(targetClass, strfmt.UUID(target)).String(), refs[0].Beacon.String(), "beacon must match")
}
type fakeVectorizer struct {
words map[string][]float32
}
func (f *fakeVectorizer) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) {
out := make([][]float32, len(words))
for i, word := range words {
vector, ok := f.words[strings.ToLower(word)]
if !ok {
continue
}
out[i] = vector
}
return out, nil
}
func (f *fakeVectorizer) VectorOnlyForCorpi(ctx context.Context, corpi []string,
overrides map[string]string,
) ([]float32, error) {
words := strings.Split(corpi[0], " ")
if len(words) == 0 {
return nil, fmt.Errorf("vector for corpi called without words")
}
vectors, _ := f.MultiVectorForWord(ctx, words)
return f.centroid(vectors, words)
}
func (f *fakeVectorizer) centroid(in [][]float32, words []string) ([]float32, error) {
withoutNilVectors := make([][]float32, len(in))
if len(in) == 0 {
return nil, fmt.Errorf("got nil vector list for words: %v", words)
}
i := 0
for _, vec := range in {
if vec == nil {
continue
}
withoutNilVectors[i] = vec
i++
}
withoutNilVectors = withoutNilVectors[:i]
if i == 0 {
return nil, fmt.Errorf("no usable words: %v", words)
}
// take the first vector assuming all have the same length
out := make([]float32, len(withoutNilVectors[0]))
for _, vec := range withoutNilVectors {
for i, dim := range vec {
out[i] = out[i] + dim
}
}
for i, sum := range out {
out[i] = sum / float32(len(withoutNilVectors))
}
return out, nil
}