KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.28 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package test
import (
"testing"
"github.com/weaviate/weaviate/entities/models"
)
func addTestDataDot(t *testing.T) {
createObject(t, &models.Object{
Class: "Dot_Class",
Properties: map[string]interface{}{
"name": "object_1",
},
Vector: []float32{
3, 4, 5, // our base object
},
})
createObject(t, &models.Object{
Class: "Dot_Class",
Properties: map[string]interface{}{
"name": "object_2",
},
Vector: []float32{
1, 1, 1, // a length-one vector
},
})
createObject(t, &models.Object{
Class: "Dot_Class",
Properties: map[string]interface{}{
"name": "object_3",
},
Vector: []float32{
0, 0, 0, // a zero vecto
},
})
createObject(t, &models.Object{
Class: "Dot_Class",
Properties: map[string]interface{}{
"name": "object_2",
},
Vector: []float32{
-3, -4, -5, // negative of the base vector
},
})
}
func testDot(t *testing.T) {
t.Run("without any limiting distance", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
-50, // the same vector as the query
-12, // the same angle as the query vector,
0, // the vector in between,
50, // the negative of the query vec
}
compareDistances(t, expectedDistances, results)
})
t.Run("with a specified certainty arg - should error", func(t *testing.T) {
ErrorGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{certainty: 0.7, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
})
t.Run("with a specified certainty prop - should error", func(t *testing.T) {
ErrorGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: 0.7, vector: [3,4,5]}){
name
_additional{certainty}
}
}
}
`)
})
t.Run("with a max distancer higher than all results, should contain all elements", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: 50, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
-50, // the same vector as the query
-12, // the same angle as the query vector,
0, // the vector in between,
50, // the negative of the query vec
}
compareDistances(t, expectedDistances, results)
})
t.Run("with a positive max distance that does not match all results, should contain 3 elems", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: 30, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
-50, // the same vector as the query
-12, // the same angle as the query vector,
0, // the vector in between,
// the last one is not contained as it would have a distance of 50, which is > 30
}
compareDistances(t, expectedDistances, results)
})
t.Run("with distance 0, should contain 3 elems", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: 0, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
-50, // the same vector as the query
-12, // the same angle as the query vector,
0, // the vector in between,
// the last one is not contained as it would have a distance of 50, which is > 0
}
compareDistances(t, expectedDistances, results)
})
t.Run("with a negative distance that should only leave the first element", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: -40, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
-50, // the same vector as the query
// the second element's distance would be -12 which is > -40
// the third element's distance would be 0 which is > -40
// the last one is not contained as it would have a distance of 50, which is > 0
}
compareDistances(t, expectedDistances, results)
})
t.Run("with a distance so small that no element should be left", func(t *testing.T) {
res := AssertGraphQL(t, nil, `
{
Get{
Dot_Class(nearVector:{distance: -60, vector: [3,4,5]}){
name
_additional{distance}
}
}
}
`)
results := res.Get("Get", "Dot_Class").AsSlice()
expectedDistances := []float32{
// all elements have a distance > -60, so nothing matches
}
compareDistances(t, expectedDistances, results)
})
}