SemanticSearchPOC / test /acceptance /classifications /knn_classification_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.5 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package test
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/go-openapi/strfmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/client/classifications"
"github.com/weaviate/weaviate/client/objects"
"github.com/weaviate/weaviate/client/schema"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/test/helper"
testhelper "github.com/weaviate/weaviate/test/helper"
)
func knnClassification(t *testing.T) {
var id strfmt.UUID
t.Run("ensure class shard for classification is ready", func(t *testing.T) {
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "READY",
func() interface{} {
shardStatus, err := helper.Client(t).Schema.SchemaObjectsShardsGet(schema.NewSchemaObjectsShardsGetParams().WithClassName("Recipe"), nil)
require.Nil(t, err)
require.GreaterOrEqual(t, len(shardStatus.Payload), 1)
return shardStatus.Payload[0].Status
}, 250*time.Millisecond, 15*time.Second)
})
t.Run("start the classification and wait for completion", func(t *testing.T) {
res, err := helper.Client(t).Classifications.ClassificationsPost(
classifications.NewClassificationsPostParams().WithParams(&models.Classification{
Class: "Recipe",
ClassifyProperties: []string{"ofType"},
BasedOnProperties: []string{"content"},
Type: "knn",
Settings: map[string]interface{}{
"k": 5,
},
}), nil)
require.Nil(t, err)
id = res.Payload.ID
// wait for classification to be completed
testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed",
func() interface{} {
res, err := helper.Client(t).Classifications.ClassificationsGet(
classifications.NewClassificationsGetParams().WithID(id.String()), nil)
require.Nil(t, err)
return res.Payload.Status
}, 100*time.Millisecond, 15*time.Second)
})
t.Run("assure changes present", func(t *testing.T) {
// wait for latest changes to be indexed / wait for consistency
testhelper.AssertEventuallyEqual(t, true, func() interface{} {
res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
WithID(unclassifiedSavory), nil)
require.Nil(t, err)
return res.Payload.Properties.(map[string]interface{})["ofType"] != nil
})
testhelper.AssertEventuallyEqual(t, true, func() interface{} {
res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
WithID(unclassifiedSweet), nil)
require.Nil(t, err)
return res.Payload.Properties.(map[string]interface{})["ofType"] != nil
})
})
t.Run("inspect unclassified savory", func(t *testing.T) {
res, err := helper.Client(t).Objects.
ObjectsGet(objects.NewObjectsGetParams().
WithID(unclassifiedSavory).
WithInclude(ptString("classification")), nil)
require.Nil(t, err)
schema, ok := res.Payload.Properties.(map[string]interface{})
require.True(t, ok)
expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s",
recipeTypeSavory)
ref := schema["ofType"].([]interface{})[0].(map[string]interface{})
assert.Equal(t, ref["beacon"].(string), expectedRefTarget)
verifyMetaDistances(t, ref)
})
t.Run("inspect unclassified sweet", func(t *testing.T) {
res, err := helper.Client(t).Objects.
ObjectsGet(objects.NewObjectsGetParams().
WithID(unclassifiedSweet).
WithInclude(ptString("classification")), nil)
require.Nil(t, err)
schema, ok := res.Payload.Properties.(map[string]interface{})
require.True(t, ok)
expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s",
recipeTypeSweet)
ref := schema["ofType"].([]interface{})[0].(map[string]interface{})
assert.Equal(t, ref["beacon"].(string), expectedRefTarget)
verifyMetaDistances(t, ref)
})
}
func verifyMetaDistances(t *testing.T, ref map[string]interface{}) {
classification, ok := ref["classification"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, json.Number("3"), classification["winningCount"])
assert.Equal(t, json.Number("2"), classification["losingCount"])
assert.Equal(t, json.Number("5"), classification["overallCount"])
closestWinning, err := classification["closestWinningDistance"].(json.Number).Float64()
require.Nil(t, err)
closestLosing, err := classification["closestLosingDistance"].(json.Number).Float64()
require.Nil(t, err)
closestOverall, err := classification["closestOverallDistance"].(json.Number).Float64()
require.Nil(t, err)
meanWinning, err := classification["meanWinningDistance"].(json.Number).Float64()
require.Nil(t, err)
meanLosing, err := classification["meanLosingDistance"].(json.Number).Float64()
require.Nil(t, err)
assert.True(t, closestWinning == closestOverall, "closestWinning == closestOverall")
assert.True(t, closestWinning < meanWinning, "closestWinning < meanWinning")
assert.True(t, closestWinning < closestLosing, "closestWinning < closestLosing")
assert.True(t, closestLosing < meanLosing, "closestLosing < meanLosing")
}