Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package grouper | |
import ( | |
"testing" | |
"github.com/go-openapi/strfmt" | |
"github.com/weaviate/weaviate/entities/schema/crossref" | |
"github.com/sirupsen/logrus/hooks/test" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/search" | |
) | |
func TestGrouper_ModeClosest(t *testing.T) { | |
in := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.98}, | |
Schema: map[string]interface{}{ | |
"name": "A1", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.96}, | |
Schema: map[string]interface{}{ | |
"name": "A2", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.93}, | |
Schema: map[string]interface{}{ | |
"name": "A3", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.98, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.93, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B2", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.92, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B3", | |
}, | |
}, | |
} | |
expectedOut := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.98}, | |
Schema: map[string]interface{}{ | |
"name": "A1", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.98, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1", | |
}, | |
}, | |
} | |
log, _ := test.NewNullLogger() | |
res, err := New(log).Group(in, "closest", 0.2) | |
require.Nil(t, err) | |
assert.Equal(t, expectedOut, res) | |
for i := range res { | |
assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) | |
} | |
} | |
func TestGrouper_ModeMerge(t *testing.T) { | |
in := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.98}, | |
Schema: map[string]interface{}{ | |
"name": "A1", | |
"count": 10.0, | |
"illegal": true, | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(20), | |
Longitude: ptFloat32(20), | |
}, | |
"relatedTo": []interface{}{ | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("1"), | |
"foo": "bar1", | |
}, | |
}, | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("2"), | |
"foo": "bar2", | |
}, | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.96}, | |
Schema: map[string]interface{}{ | |
"name": "A2", | |
"count": 11.0, | |
"illegal": true, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.96}, | |
Schema: map[string]interface{}{ | |
"name": "A2", | |
"count": 11.0, | |
"illegal": true, | |
"relatedTo": []interface{}{ | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("3"), | |
"foo": "bar3", | |
}, | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.93}, | |
Schema: map[string]interface{}{ | |
"name": "A3", | |
"count": 12.0, | |
"illegal": false, | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(22), | |
Longitude: ptFloat32(18), | |
}, | |
"relatedTo": []interface{}{ | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("2"), | |
"foo": "bar2", | |
}, | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.98, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.93, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B2", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.92, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B3", | |
}, | |
}, | |
} | |
expectedOut := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs | |
Schema: map[string]interface{}{ | |
"name": "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set | |
"count": 11.0, // mean of all inputs | |
"illegal": true, // the most common input value, with a bias towards true on equal count | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(21), | |
Longitude: ptFloat32(19), | |
}, | |
"relatedTo": []interface{}{ | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("1"), | |
"foo": "bar1", | |
}, | |
}, | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("2"), | |
"foo": "bar2", | |
}, | |
}, | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": strfmt.UUID("3"), | |
"foo": "bar3", | |
}, | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.9433334, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1 (B2, B3)", | |
}, | |
}, | |
} | |
log, _ := test.NewNullLogger() | |
res, err := New(log).Group(in, "merge", 0.2) | |
require.Nil(t, err) | |
assert.Equal(t, expectedOut, res) | |
for i := range res { | |
assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) | |
} | |
} | |
// Since reference properties can be represented both as models.MultipleRef | |
// and []interface{}, we need to test for both cases. TestGrouper_ModeMerge | |
// above tests the case of []interface{}, so this test handles the other case. | |
// see https://github.com/weaviate/weaviate/pull/2320 for more info | |
func Test_Grouper_ModeMerge_MultipleRef(t *testing.T) { | |
in := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.98}, | |
Schema: map[string]interface{}{ | |
"name": "A1", | |
"count": 10.0, | |
"illegal": true, | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(20), | |
Longitude: ptFloat32(20), | |
}, | |
"relatedTo": models.MultipleRef{ | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()), | |
Class: "Foo", | |
}, | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), | |
Class: "Foo", | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.96}, | |
Schema: map[string]interface{}{ | |
"name": "A2", | |
"count": 11.0, | |
"illegal": true, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.96}, | |
Schema: map[string]interface{}{ | |
"name": "A2", | |
"count": 11.0, | |
"illegal": true, | |
"relatedTo": models.MultipleRef{ | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()), | |
Class: "Foo", | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.93}, | |
Schema: map[string]interface{}{ | |
"name": "A3", | |
"count": 12.0, | |
"illegal": false, | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(22), | |
Longitude: ptFloat32(18), | |
}, | |
"relatedTo": models.MultipleRef{ | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), | |
Class: "Foo", | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.98, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.93, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B2", | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.92, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B3", | |
}, | |
}, | |
} | |
expectedOut := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs | |
Schema: map[string]interface{}{ | |
"name": "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set | |
"count": 11.0, // mean of all inputs | |
"illegal": true, // the most common input value, with a bias towards true on equal count | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(21), | |
Longitude: ptFloat32(19), | |
}, | |
"relatedTo": []interface{}{ | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()), | |
Class: "Foo", | |
}, | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), | |
Class: "Foo", | |
}, | |
&models.SingleRef{ | |
Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()), | |
Class: "Foo", | |
}, | |
}, | |
}, | |
}, | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.9433334, 0.1}, | |
Schema: map[string]interface{}{ | |
"name": "B1 (B2, B3)", | |
}, | |
}, | |
} | |
log, _ := test.NewNullLogger() | |
res, err := New(log).Group(in, "merge", 0.2) | |
require.Nil(t, err) | |
assert.Equal(t, expectedOut, res) | |
for i := range res { | |
assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) | |
} | |
} | |
func TestGrouper_ModeMergeFailWithIDTypeOtherThenUUID(t *testing.T) { | |
in := []search.Result{ | |
{ | |
ClassName: "Foo", | |
Vector: []float32{0.1, 0.1, 0.98}, | |
Schema: map[string]interface{}{ | |
"name": "A1", | |
"count": 10.0, | |
"illegal": true, | |
"location": &models.GeoCoordinates{ | |
Latitude: ptFloat32(20), | |
Longitude: ptFloat32(20), | |
}, | |
"relatedTo": []interface{}{ | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": "1", | |
"foo": "bar1", | |
}, | |
}, | |
search.LocalRef{ | |
Class: "Foo", | |
Fields: map[string]interface{}{ | |
"id": "2", | |
"foo": "bar2", | |
}, | |
}, | |
}, | |
}, | |
}, | |
} | |
log, _ := test.NewNullLogger() | |
res, err := New(log).Group(in, "merge", 0.2) | |
require.NotNil(t, err) | |
assert.Nil(t, res) | |
assert.EqualError(t, err, | |
"group 0: merge values: prop 'relatedTo': element 0: "+ | |
"found a search.LocalRef, 'id' field type expected to be strfmt.UUID but got string") | |
} | |