Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
//go:build integrationTest | |
// +build integrationTest | |
package classification_integration_test | |
import ( | |
"context" | |
"encoding/json" | |
"testing" | |
"time" | |
"github.com/go-openapi/strfmt" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/adapters/repos/db" | |
"github.com/weaviate/weaviate/entities/dto" | |
"github.com/weaviate/weaviate/entities/filters" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/schema" | |
testhelper "github.com/weaviate/weaviate/test/helper" | |
"github.com/weaviate/weaviate/usecases/classification" | |
"github.com/weaviate/weaviate/usecases/objects" | |
) | |
func Test_Classifier_KNN_SaveConsistency(t *testing.T) { | |
dirName := t.TempDir() | |
logger, _ := test.NewNullLogger() | |
var id strfmt.UUID | |
shardState := singleShardState() | |
sg := &fakeSchemaGetter{ | |
schema: schema.Schema{Objects: &models.Schema{Classes: nil}}, | |
shardState: shardState, | |
} | |
vrepo, err := db.New(logger, db.Config{ | |
MemtablesFlushIdleAfter: 60, | |
RootPath: dirName, | |
QueryMaximumResults: 10000, | |
MaxImportGoroutinesFactor: 1, | |
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) | |
require.Nil(t, err) | |
vrepo.SetSchemaGetter(sg) | |
require.Nil(t, vrepo.WaitForStartup(context.Background())) | |
migrator := db.NewMigrator(vrepo, logger) | |
// so we can reuse it for follow up requests, such as checking the status | |
size := 400 | |
data := largeTestDataSize(size) | |
t.Run("preparations", func(t *testing.T) { | |
t.Run("creating the classes", func(t *testing.T) { | |
for _, c := range testSchema().Objects.Classes { | |
require.Nil(t, | |
migrator.AddClass(context.Background(), c, shardState)) | |
} | |
sg.schema = testSchema() | |
}) | |
t.Run("importing the training data", func(t *testing.T) { | |
classified := testDataAlreadyClassified() | |
bt := make(objects.BatchObjects, len(classified)) | |
for i, elem := range classified { | |
bt[i] = objects.BatchObject{ | |
OriginalIndex: i, | |
UUID: elem.ID, | |
Vector: elem.Vector, | |
Object: elem.Object(), | |
} | |
} | |
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) | |
require.Nil(t, err) | |
for _, elem := range res { | |
require.Nil(t, elem.Err) | |
} | |
}) | |
t.Run("importing the to be classified data", func(t *testing.T) { | |
bt := make(objects.BatchObjects, size) | |
for i, elem := range data { | |
bt[i] = objects.BatchObject{ | |
OriginalIndex: i, | |
UUID: elem.ID, | |
Vector: elem.Vector, | |
Object: elem.Object(), | |
} | |
} | |
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) | |
require.Nil(t, err) | |
for _, elem := range res { | |
require.Nil(t, elem.Err) | |
} | |
}) | |
}) | |
t.Run("classification journey", func(t *testing.T) { | |
repo := newFakeClassificationRepo() | |
authorizer := &fakeAuthorizer{} | |
classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil) | |
params := models.Classification{ | |
Class: "Article", | |
BasedOnProperties: []string{"description"}, | |
ClassifyProperties: []string{"exactCategory", "mainCategory"}, | |
Settings: map[string]interface{}{ | |
"k": json.Number("1"), | |
}, | |
} | |
t.Run("scheduling a classification", func(t *testing.T) { | |
class, err := classifier.Schedule(context.Background(), nil, params) | |
require.Nil(t, err, "should not error") | |
require.NotNil(t, class) | |
assert.Len(t, class.ID, 36, "an id was assigned") | |
id = class.ID | |
}) | |
t.Run("retrieving the same classification by id", func(t *testing.T) { | |
class, err := classifier.Get(context.Background(), nil, id) | |
require.Nil(t, err) | |
require.NotNil(t, class) | |
assert.Equal(t, id, class.ID) | |
assert.Equal(t, models.ClassificationStatusRunning, class.Status) | |
}) | |
waitForStatusToNoLongerBeRunning(t, classifier, id) | |
t.Run("status is now completed", func(t *testing.T) { | |
class, err := classifier.Get(context.Background(), nil, id) | |
require.Nil(t, err) | |
require.NotNil(t, class) | |
assert.Equal(t, models.ClassificationStatusCompleted, class.Status) | |
assert.Equal(t, int64(400), class.Meta.CountSucceeded) | |
}) | |
t.Run("verify everything is classified", func(t *testing.T) { | |
filter := filters.LocalFilter{ | |
Root: &filters.Clause{ | |
Operator: filters.OperatorEqual, | |
On: &filters.Path{ | |
Class: "Article", | |
Property: "exactCategory", | |
}, | |
Value: &filters.Value{ | |
Value: 0, | |
Type: schema.DataTypeInt, | |
}, | |
}, | |
} | |
res, err := vrepo.Search(context.Background(), dto.GetParams{ | |
ClassName: "Article", | |
Filters: &filter, | |
Pagination: &filters.Pagination{ | |
Limit: 10000, | |
}, | |
}) | |
require.Nil(t, err) | |
assert.Equal(t, 0, len(res)) | |
}) | |
}) | |
} | |
func Test_Classifier_ZeroShot_SaveConsistency(t *testing.T) { | |
t.Skip() | |
dirName := t.TempDir() | |
logger, _ := test.NewNullLogger() | |
var id strfmt.UUID | |
sg := &fakeSchemaGetter{shardState: singleShardState()} | |
vrepo, err := db.New(logger, db.Config{ | |
RootPath: dirName, | |
QueryMaximumResults: 10000, | |
MaxImportGoroutinesFactor: 1, | |
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) | |
require.Nil(t, err) | |
vrepo.SetSchemaGetter(sg) | |
require.Nil(t, vrepo.WaitForStartup(context.Background())) | |
migrator := db.NewMigrator(vrepo, logger) | |
t.Run("preparations", func(t *testing.T) { | |
t.Run("creating the classes", func(t *testing.T) { | |
for _, c := range testSchemaForZeroShot().Objects.Classes { | |
require.Nil(t, | |
migrator.AddClass(context.Background(), c, sg.shardState)) | |
} | |
sg.schema = testSchemaForZeroShot() | |
}) | |
t.Run("importing the training data", func(t *testing.T) { | |
classified := testDataZeroShotUnclassified() | |
bt := make(objects.BatchObjects, len(classified)) | |
for i, elem := range classified { | |
bt[i] = objects.BatchObject{ | |
OriginalIndex: i, | |
UUID: elem.ID, | |
Vector: elem.Vector, | |
Object: elem.Object(), | |
} | |
} | |
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) | |
require.Nil(t, err) | |
for _, elem := range res { | |
require.Nil(t, elem.Err) | |
} | |
}) | |
}) | |
t.Run("classification journey", func(t *testing.T) { | |
repo := newFakeClassificationRepo() | |
authorizer := &fakeAuthorizer{} | |
classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil) | |
params := models.Classification{ | |
Class: "Recipes", | |
BasedOnProperties: []string{"text"}, | |
ClassifyProperties: []string{"ofFoodType"}, | |
Type: "zeroshot", | |
} | |
t.Run("scheduling a classification", func(t *testing.T) { | |
class, err := classifier.Schedule(context.Background(), nil, params) | |
require.Nil(t, err, "should not error") | |
require.NotNil(t, class) | |
assert.Len(t, class.ID, 36, "an id was assigned") | |
id = class.ID | |
}) | |
t.Run("retrieving the same classification by id", func(t *testing.T) { | |
class, err := classifier.Get(context.Background(), nil, id) | |
require.Nil(t, err) | |
require.NotNil(t, class) | |
assert.Equal(t, id, class.ID) | |
assert.Equal(t, models.ClassificationStatusRunning, class.Status) | |
}) | |
waitForStatusToNoLongerBeRunning(t, classifier, id) | |
t.Run("status is now completed", func(t *testing.T) { | |
class, err := classifier.Get(context.Background(), nil, id) | |
require.Nil(t, err) | |
require.NotNil(t, class) | |
assert.Equal(t, models.ClassificationStatusCompleted, class.Status) | |
assert.Equal(t, int64(2), class.Meta.CountSucceeded) | |
}) | |
t.Run("verify everything is classified", func(t *testing.T) { | |
filter := filters.LocalFilter{ | |
Root: &filters.Clause{ | |
Operator: filters.OperatorEqual, | |
On: &filters.Path{ | |
Class: "Recipes", | |
Property: "ofFoodType", | |
}, | |
Value: &filters.Value{ | |
Value: 0, | |
Type: schema.DataTypeInt, | |
}, | |
}, | |
} | |
res, err := vrepo.Search(context.Background(), dto.GetParams{ | |
ClassName: "Recipes", | |
Filters: &filter, | |
Pagination: &filters.Pagination{ | |
Limit: 100000, | |
}, | |
}) | |
require.Nil(t, err) | |
assert.Equal(t, 0, len(res)) | |
}) | |
}) | |
} | |
func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *classification.Classifier, id strfmt.UUID) { | |
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} { | |
class, err := classifier.Get(context.Background(), nil, id) | |
require.Nil(t, err) | |
require.NotNil(t, class) | |
return class.Status != models.ClassificationStatusRunning | |
}, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running") | |
} | |