// Source: SemanticSearchPOC/test/helper/journey/group_by_journey_tests.go
// Commit b110593 ("Adding in weaviate code") by KevinStephenson, 4.78 kB.
//
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package journey
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/test/helper"
graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql"
"github.com/weaviate/weaviate/test/helper/sample-schema/documents"
)
// GroupBySingleAndMultiShardTests runs the groupBy journey for the
// Passage/Document sample schema. For every table entry it creates the
// classes and objects (single- or multi-shard), issues a nearObject query
// grouped by the "ofDocument" reference, verifies the returned groups and the
// order of their hits, and finally deletes the classes again.
//
// When weaviateEndpoint is non-empty, helper.SetupClient is called with it
// before any test runs; when it is empty, SetupClient is not called.
func GroupBySingleAndMultiShardTests(t *testing.T, weaviateEndpoint string) {
	if weaviateEndpoint != "" {
		helper.SetupClient(weaviateEndpoint)
	}

	// helper methods

	// getGroup extracts the `_additional.group` object from one result row.
	getGroup := func(value interface{}) map[string]interface{} {
		additional := value.(map[string]interface{})["_additional"].(map[string]interface{})
		return additional["group"].(map[string]interface{})
	}

	// getGroupHits returns the group's `groupedBy.value` together with the ids
	// of its hits, in the order in which the response lists them.
	getGroupHits := func(group map[string]interface{}) (string, []string) {
		hits := group["hits"].([]interface{})
		ids := make([]string, 0, len(hits))
		for _, hit := range hits {
			additional := hit.(map[string]interface{})["_additional"].(map[string]interface{})
			ids = append(ids, additional["id"].(string))
		}
		groupedBy := group["groupedBy"].(map[string]interface{})
		return groupedBy["value"].(string), ids
	}

	// test methods

	// create sets up the sample classes and imports every sample object,
	// waiting for each object to become retrievable.
	create := func(t *testing.T, multishard bool) {
		for _, class := range documents.ClassesContextionaryVectorizer(multishard) {
			helper.CreateClass(t, class)
		}
		for _, obj := range documents.Objects() {
			helper.CreateObject(t, obj)
			helper.AssertGetObjectEventually(t, obj.Class, obj.ID)
		}
	}

	// groupBy runs the grouped nearObject query and checks that the groups
	// come back in the expected order with the expected hits.
	groupBy := func(t *testing.T, groupsCount, objectsPerGroup int) {
		query := `
{
Get{
Passage(
nearObject:{
id: "00000000-0000-0000-0000-000000000001"
}
groupBy:{
path:["ofDocument"]
groups:%v
objectsPerGroup:%v
}
){
_additional{
id
group{
groupedBy{value}
count
maxDistance
minDistance
hits {
_additional{
id
distance
}
}
}
}
}
}
}
`
		result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, fmt.Sprintf(query, groupsCount, objectsPerGroup))
		groups := result.Get("Get", "Passage").AsSlice()
		require.Len(t, groups, groupsCount)

		groupedBy1 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000011`
		groupedBy2 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000012`
		expectedResults := map[string][]string{
			groupedBy1: {
				documents.PassageIDs[0].String(),
				documents.PassageIDs[5].String(),
				documents.PassageIDs[4].String(),
				documents.PassageIDs[3].String(),
				documents.PassageIDs[2].String(),
				documents.PassageIDs[1].String(),
			},
			groupedBy2: {
				documents.PassageIDs[6].String(),
				documents.PassageIDs[7].String(),
			},
		}
		groupsOrder := []string{groupedBy1, groupedBy2}

		// Guard the groupsOrder[i] lookup below: fail the test instead of
		// panicking when more groups are returned than we have expectations for.
		if len(groups) > len(groupsOrder) {
			t.Fatalf("got %d groups but only %d are expected", len(groups), len(groupsOrder))
		}

		for i, current := range groups {
			groupedBy, ids := getGroupHits(getGroup(current))
			assert.Equal(t, groupsOrder[i], groupedBy)

			expected, known := expectedResults[groupedBy]
			if !known {
				// The assertion above has already reported the mismatch; there
				// is nothing to compare the hits against.
				continue
			}
			if len(ids) > len(expected) {
				t.Errorf("group %q returned %d hits, expected at most %d", groupedBy, len(ids), len(expected))
				continue
			}
			// The group may hold fewer hits than listed (objectsPerGroup caps
			// the result), so only the leading expected ids are compared.
			assert.Equal(t, expected[:len(ids)], ids)
		}
	}

	// deleteClasses removes both sample classes so the next table entry starts
	// from a clean schema.
	deleteClasses := func(t *testing.T) {
		helper.DeleteClass(t, documents.Passage)
		helper.DeleteClass(t, documents.Document)
	}

	// tests
	tests := []struct {
		name                    string
		multishard              bool
		groups, objectsPerGroup int
	}{
		{
			name:            "single shard - 2 groups 10 objects per group",
			multishard:      false,
			groups:          2,
			objectsPerGroup: 10,
		},
		{
			name:            "multi shard - 2 groups 10 objects per group",
			multishard:      true,
			groups:          2,
			objectsPerGroup: 10,
		},
		{
			name:            "single shard - 1 groups 1 objects per group",
			multishard:      false,
			groups:          1,
			objectsPerGroup: 1,
		},
		{
			name:            "multi shard - 1 groups 1 objects per group",
			multishard:      true,
			groups:          1,
			objectsPerGroup: 1,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Run("create", func(t *testing.T) {
				create(t, tt.multishard)
			})
			t.Run("group by", func(t *testing.T) {
				groupBy(t, tt.groups, tt.objectsPerGroup)
			})
			t.Run("delete", func(t *testing.T) {
				deleteClasses(t)
			})
		})
	}
}