//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
//go:build integrationTest
// +build integrationTest

package classification_integration_test

import (
"context"
"encoding/json"
"testing"
"time"
"github.com/go-openapi/strfmt"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/adapters/repos/db"
"github.com/weaviate/weaviate/entities/dto"
"github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
testhelper "github.com/weaviate/weaviate/test/helper"
"github.com/weaviate/weaviate/usecases/classification"
"github.com/weaviate/weaviate/usecases/objects"
)

func Test_Classifier_KNN_SaveConsistency(t *testing.T) {
dirName := t.TempDir()
logger, _ := test.NewNullLogger()
var id strfmt.UUID // set when the classification is scheduled below, so we can reuse it for follow up requests, such as checking the status
shardState := singleShardState()
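// the schema getter is a test fake; its schema is populated further down,
// once the classes have been created through the migrator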
sg := &fakeSchemaGetter{
schema: schema.Schema{Objects: &models.Schema{Classes: nil}},
shardState: shardState,
}
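// a real on-disk repo rooted in the temp dir; the cluster-related
// dependencies (remote client, node resolver, replication client) are
// test fakes because everything runs on a single local node here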
vrepo, err := db.New(logger, db.Config{
MemtablesFlushIdleAfter: 60,
RootPath: dirName,
QueryMaximumResults: 10000,
MaxImportGoroutinesFactor: 1,
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
require.Nil(t, err)
vrepo.SetSchemaGetter(sg)
require.Nil(t, vrepo.WaitForStartup(context.Background()))
migrator := db.NewMigrator(vrepo, logger)
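// 400 objects that still need to be classified; they are imported further
// down as part of the preparations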
size := 400
data := largeTestDataSize(size)
t.Run("preparations", func(t *testing.T) {
t.Run("creating the classes", func(t *testing.T) {
for _, c := range testSchema().Objects.Classes {
require.Nil(t,
migrator.AddClass(context.Background(), c, shardState))
}
sg.schema = testSchema()
})
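// the pre-classified objects serve as the kNN training set: the classifier
// later copies their category references onto the nearest unclassified objects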
t.Run("importing the training data", func(t *testing.T) {
classified := testDataAlreadyClassified()
bt := make(objects.BatchObjects, len(classified))
for i, elem := range classified {
bt[i] = objects.BatchObject{
OriginalIndex: i,
UUID: elem.ID,
Vector: elem.Vector,
Object: elem.Object(),
}
}
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
require.Nil(t, err)
for _, elem := range res {
require.Nil(t, elem.Err)
}
})
t.Run("importing the to be classified data", func(t *testing.T) {
bt := make(objects.BatchObjects, size)
for i, elem := range data {
bt[i] = objects.BatchObject{
OriginalIndex: i,
UUID: elem.ID,
Vector: elem.Vector,
Object: elem.Object(),
}
}
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
require.Nil(t, err)
for _, elem := range res {
require.Nil(t, elem.Err)
}
})
})
t.Run("classification journey", func(t *testing.T) {
repo := newFakeClassificationRepo()
authorizer := &fakeAuthorizer{}
classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil)
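// classify the exactCategory and mainCategory references of all Article
// objects using a k=1 nearest-neighbour lookup over the description-based
// vectors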
params := models.Classification{
Class: "Article",
BasedOnProperties: []string{"description"},
ClassifyProperties: []string{"exactCategory", "mainCategory"},
Settings: map[string]interface{}{
"k": json.Number("1"),
},
}
t.Run("scheduling a classification", func(t *testing.T) {
class, err := classifier.Schedule(context.Background(), nil, params)
require.Nil(t, err, "should not error")
require.NotNil(t, class)
assert.Len(t, class.ID, 36, "an id was assigned")
id = class.ID
})
t.Run("retrieving the same classification by id", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, id, class.ID)
assert.Equal(t, models.ClassificationStatusRunning, class.Status)
})
waitForStatusToNoLongerBeRunning(t, classifier, id)
t.Run("status is now completed", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
assert.Equal(t, int64(400), class.Meta.CountSucceeded)
})
t.Run("verify everything is classified", func(t *testing.T) {
filter := filters.LocalFilter{
Root: &filters.Clause{
Operator: filters.OperatorEqual,
On: &filters.Path{
Class: "Article",
Property: "exactCategory",
},
Value: &filters.Value{
Value: 0,
Type: schema.DataTypeInt,
},
},
}
res, err := vrepo.Search(context.Background(), dto.GetParams{
ClassName: "Article",
Filters: &filter,
Pagination: &filters.Pagination{
Limit: 10000,
},
})
require.Nil(t, err)
assert.Equal(t, 0, len(res))
})
})
}

func Test_Classifier_ZeroShot_SaveConsistency(t *testing.T) {
t.Skip()
dirName := t.TempDir()
logger, _ := test.NewNullLogger()
var id strfmt.UUID
sg := &fakeSchemaGetter{shardState: singleShardState()}
vrepo, err := db.New(logger, db.Config{
RootPath: dirName,
QueryMaximumResults: 10000,
MaxImportGoroutinesFactor: 1,
}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
require.Nil(t, err)
vrepo.SetSchemaGetter(sg)
require.Nil(t, vrepo.WaitForStartup(context.Background()))
migrator := db.NewMigrator(vrepo, logger)
t.Run("preparations", func(t *testing.T) {
t.Run("creating the classes", func(t *testing.T) {
for _, c := range testSchemaForZeroShot().Objects.Classes {
require.Nil(t,
migrator.AddClass(context.Background(), c, sg.shardState))
}
sg.schema = testSchemaForZeroShot()
})
t.Run("importing the training data", func(t *testing.T) {
classified := testDataZeroShotUnclassified()
bt := make(objects.BatchObjects, len(classified))
for i, elem := range classified {
bt[i] = objects.BatchObject{
OriginalIndex: i,
UUID: elem.ID,
Vector: elem.Vector,
Object: elem.Object(),
}
}
res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
require.Nil(t, err)
for _, elem := range res {
require.Nil(t, elem.Err)
}
})
})
t.Run("classification journey", func(t *testing.T) {
repo := newFakeClassificationRepo()
authorizer := &fakeAuthorizer{}
classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil)
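// zero-shot classification needs no pre-classified training data: each
// Recipe is matched to the closest candidate target object for its
// ofFoodType reference purely by vector distance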
params := models.Classification{
Class: "Recipes",
BasedOnProperties: []string{"text"},
ClassifyProperties: []string{"ofFoodType"},
Type: "zeroshot",
}
t.Run("scheduling a classification", func(t *testing.T) {
class, err := classifier.Schedule(context.Background(), nil, params)
require.Nil(t, err, "should not error")
require.NotNil(t, class)
assert.Len(t, class.ID, 36, "an id was assigned")
id = class.ID
})
t.Run("retrieving the same classification by id", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, id, class.ID)
assert.Equal(t, models.ClassificationStatusRunning, class.Status)
})
waitForStatusToNoLongerBeRunning(t, classifier, id)
t.Run("status is now completed", func(t *testing.T) {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
assert.Equal(t, int64(2), class.Meta.CountSucceeded)
})
t.Run("verify everything is classified", func(t *testing.T) {
filter := filters.LocalFilter{
Root: &filters.Clause{
Operator: filters.OperatorEqual,
On: &filters.Path{
Class: "Recipes",
Property: "ofFoodType",
},
Value: &filters.Value{
Value: 0,
Type: schema.DataTypeInt,
},
},
}
res, err := vrepo.Search(context.Background(), dto.GetParams{
ClassName: "Recipes",
Filters: &filter,
Pagination: &filters.Pagination{
Limit: 100000,
},
})
require.Nil(t, err)
assert.Equal(t, 0, len(res))
})
})
}
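// waitForStatusToNoLongerBeRunning polls the classification status every
// 100ms until it leaves "running", failing the test if that takes longer
// than 20 seconds.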
func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *classification.Classifier, id strfmt.UUID) {
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} {
class, err := classifier.Get(context.Background(), nil, id)
require.Nil(t, err)
require.NotNil(t, class)
return class.Status != models.ClassificationStatusRunning
}, 100*time.Millisecond, 20*time.Second, "wait until status is no longer running")
}