Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package test | |
import ( | |
"encoding/json" | |
"fmt" | |
"testing" | |
"time" | |
"github.com/go-openapi/strfmt" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/client/classifications" | |
"github.com/weaviate/weaviate/client/objects" | |
"github.com/weaviate/weaviate/client/schema" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/test/helper" | |
testhelper "github.com/weaviate/weaviate/test/helper" | |
) | |
func knnClassification(t *testing.T) { | |
var id strfmt.UUID | |
t.Run("ensure class shard for classification is ready", func(t *testing.T) { | |
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "READY", | |
func() interface{} { | |
shardStatus, err := helper.Client(t).Schema.SchemaObjectsShardsGet(schema.NewSchemaObjectsShardsGetParams().WithClassName("Recipe"), nil) | |
require.Nil(t, err) | |
require.GreaterOrEqual(t, len(shardStatus.Payload), 1) | |
return shardStatus.Payload[0].Status | |
}, 250*time.Millisecond, 15*time.Second) | |
}) | |
t.Run("start the classification and wait for completion", func(t *testing.T) { | |
res, err := helper.Client(t).Classifications.ClassificationsPost( | |
classifications.NewClassificationsPostParams().WithParams(&models.Classification{ | |
Class: "Recipe", | |
ClassifyProperties: []string{"ofType"}, | |
BasedOnProperties: []string{"content"}, | |
Type: "knn", | |
Settings: map[string]interface{}{ | |
"k": 5, | |
}, | |
}), nil) | |
require.Nil(t, err) | |
id = res.Payload.ID | |
// wait for classification to be completed | |
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed", | |
func() interface{} { | |
res, err := helper.Client(t).Classifications.ClassificationsGet( | |
classifications.NewClassificationsGetParams().WithID(id.String()), nil) | |
require.Nil(t, err) | |
return res.Payload.Status | |
}, 100*time.Millisecond, 15*time.Second) | |
}) | |
t.Run("assure changes present", func(t *testing.T) { | |
// wait for latest changes to be indexed / wait for consistency | |
testhelper.AssertEventuallyEqual(t, true, func() interface{} { | |
res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). | |
WithID(unclassifiedSavory), nil) | |
require.Nil(t, err) | |
return res.Payload.Properties.(map[string]interface{})["ofType"] != nil | |
}) | |
testhelper.AssertEventuallyEqual(t, true, func() interface{} { | |
res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). | |
WithID(unclassifiedSweet), nil) | |
require.Nil(t, err) | |
return res.Payload.Properties.(map[string]interface{})["ofType"] != nil | |
}) | |
}) | |
t.Run("inspect unclassified savory", func(t *testing.T) { | |
res, err := helper.Client(t).Objects. | |
ObjectsGet(objects.NewObjectsGetParams(). | |
WithID(unclassifiedSavory). | |
WithInclude(ptString("classification")), nil) | |
require.Nil(t, err) | |
schema, ok := res.Payload.Properties.(map[string]interface{}) | |
require.True(t, ok) | |
expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s", | |
recipeTypeSavory) | |
ref := schema["ofType"].([]interface{})[0].(map[string]interface{}) | |
assert.Equal(t, ref["beacon"].(string), expectedRefTarget) | |
verifyMetaDistances(t, ref) | |
}) | |
t.Run("inspect unclassified sweet", func(t *testing.T) { | |
res, err := helper.Client(t).Objects. | |
ObjectsGet(objects.NewObjectsGetParams(). | |
WithID(unclassifiedSweet). | |
WithInclude(ptString("classification")), nil) | |
require.Nil(t, err) | |
schema, ok := res.Payload.Properties.(map[string]interface{}) | |
require.True(t, ok) | |
expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s", | |
recipeTypeSweet) | |
ref := schema["ofType"].([]interface{})[0].(map[string]interface{}) | |
assert.Equal(t, ref["beacon"].(string), expectedRefTarget) | |
verifyMetaDistances(t, ref) | |
}) | |
} | |
func verifyMetaDistances(t *testing.T, ref map[string]interface{}) { | |
classification, ok := ref["classification"].(map[string]interface{}) | |
require.True(t, ok) | |
assert.Equal(t, json.Number("3"), classification["winningCount"]) | |
assert.Equal(t, json.Number("2"), classification["losingCount"]) | |
assert.Equal(t, json.Number("5"), classification["overallCount"]) | |
closestWinning, err := classification["closestWinningDistance"].(json.Number).Float64() | |
require.Nil(t, err) | |
closestLosing, err := classification["closestLosingDistance"].(json.Number).Float64() | |
require.Nil(t, err) | |
closestOverall, err := classification["closestOverallDistance"].(json.Number).Float64() | |
require.Nil(t, err) | |
meanWinning, err := classification["meanWinningDistance"].(json.Number).Float64() | |
require.Nil(t, err) | |
meanLosing, err := classification["meanLosingDistance"].(json.Number).Float64() | |
require.Nil(t, err) | |
assert.True(t, closestWinning == closestOverall, "closestWinning == closestOverall") | |
assert.True(t, closestWinning < meanWinning, "closestWinning < meanWinning") | |
assert.True(t, closestWinning < closestLosing, "closestWinning < closestLosing") | |
assert.True(t, closestLosing < meanLosing, "closestLosing < meanLosing") | |
} | |