KevinStephenson
Adding in weaviate code
b110593
raw
history blame
4.06 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package objects
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/search"
)
// TestQuery drives Manager.Query through a single fake manager and checks the
// returned status code and objects for a repository "not found" error, an
// authorization failure, a lock failure, a validation failure (offset+limit
// beyond QueryMaximumResults), and two success paths. One success path
// requests vectors and therefore also expects usage dimensions to be recorded.
//
// All subtests share the same fake manager (m) and mutate its authorizer,
// lock and mock expectations, so the subtests must run sequentially; only the
// top-level test is marked parallel.
func TestQuery(t *testing.T) {
	t.Parallel()
	var (
		cls    = "MyClass"
		m      = newFakeGetManager(schema.Schema{})
		errAny = errors.New("any")
	)
	params := QueryParams{
		Class: cls,
		Limit: ptInt64(10),
	}
	inputs := QueryInput{
		Class: cls,
		Limit: 10,
	}
	tests := []struct {
		name              string
		param             QueryParams
		mockedErr         *Error
		authErr           error
		lockErr           error
		wantCode          int
		mockedDBResponse  []search.Result
		wantResponse      []*models.Object
		wantQueryInput    QueryInput
		wantUsageTracking bool
	}{
		{
			name:           "not found",
			param:          params,
			mockedErr:      &Error{Code: StatusNotFound},
			wantCode:       StatusNotFound,
			wantQueryInput: inputs,
		},
		{
			name:           "forbidden",
			param:          params,
			authErr:        errAny,
			wantCode:       StatusForbidden,
			wantQueryInput: inputs,
		},
		{
			name:           "internal error",
			param:          params,
			lockErr:        errAny,
			wantCode:       StatusInternalServerError,
			wantQueryInput: inputs,
		},
		{
			name:  "happy path",
			param: params,
			mockedDBResponse: []search.Result{
				{
					ClassName: cls,
					Schema: map[string]interface{}{
						"foo": "bar",
					},
					Dims: 3,
					Dist: 0,
				},
			},
			wantResponse: []*models.Object{{
				Class:         cls,
				VectorWeights: map[string]string(nil),
				Properties: map[string]interface{}{
					"foo": "bar",
				},
			}},
			wantQueryInput: inputs,
		},
		{
			name: "happy path with explicit vector requested",
			param: QueryParams{
				Class:      cls,
				Limit:      ptInt64(10),
				Additional: additional.Properties{Vector: true},
			},
			mockedDBResponse: []search.Result{
				{
					ClassName: cls,
					Schema: map[string]interface{}{
						"foo": "bar",
					},
					Dims: 3,
				},
			},
			wantResponse: []*models.Object{{
				Class:         cls,
				VectorWeights: map[string]string(nil),
				Properties: map[string]interface{}{
					"foo": "bar",
				},
			}},
			wantQueryInput: QueryInput{
				Class:      cls,
				Limit:      10,
				Additional: additional.Properties{Vector: true},
			},
			wantUsageTracking: true,
		},
		{
			// Offset 1 plus a limit equal to QueryMaximumResults exceeds the
			// configured maximum, which the test expects to be rejected.
			name:           "bad request",
			param:          QueryParams{Class: cls, Offset: ptInt64(1), Limit: &m.config.Config.QueryMaximumResults},
			wantCode:       StatusBadRequest,
			wantQueryInput: inputs,
		},
	}
	for i, tc := range tests {
		// Give every iteration its own copy. The repo mock below keeps a
		// pointer to tc.wantQueryInput as its expected argument and compares
		// it by value. With pre-Go 1.22 loop semantics, that pointer would
		// otherwise alias the single range variable that later iterations
		// overwrite.
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			m.authorizer.Err = tc.authErr
			m.locks.Err = tc.lockErr
			// The forbidden and internal-error cases are expected to fail
			// before the repository is queried, so no repo call is mocked
			// for them.
			if tc.authErr == nil && tc.lockErr == nil {
				m.repo.On("Query", &tc.wantQueryInput).Return(tc.mockedDBResponse, tc.mockedErr).Once()
			}
			// When vectors are requested, the metrics fake must accept a
			// usage-dimensions call for the first returned result.
			if tc.wantUsageTracking {
				m.metrics.On("AddUsageDimensions", cls, "get_rest", "list_include_vector",
					tc.mockedDBResponse[0].Dims)
			}
			res, err := m.Manager.Query(context.Background(), nil, &tc.param)
			code := 0
			if err != nil {
				code = err.Code
			}
			if tc.wantCode != code {
				t.Errorf("case %d expected:%v got:%v", i+1, tc.wantCode, code)
			}
			assert.Equal(t, tc.wantResponse, res)
		})
	}
}