Spaces:
Running
Running
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package objects | |
import ( | |
"context" | |
"errors" | |
"testing" | |
"github.com/stretchr/testify/assert" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/schema" | |
"github.com/weaviate/weaviate/entities/search" | |
) | |
func TestQuery(t *testing.T) { | |
t.Parallel() | |
var ( | |
cls = "MyClass" | |
m = newFakeGetManager(schema.Schema{}) | |
errAny = errors.New("any") | |
) | |
params := QueryParams{ | |
Class: cls, | |
Limit: ptInt64(10), | |
} | |
inputs := QueryInput{ | |
Class: cls, | |
Limit: 10, | |
} | |
tests := []struct { | |
class string | |
name string | |
param QueryParams | |
mockedErr *Error | |
authErr error | |
lockErr error | |
wantCode int | |
mockedDBResponse []search.Result | |
wantResponse []*models.Object | |
wantQueryInput QueryInput | |
wantUsageTracking bool | |
}{ | |
{ | |
name: "not found", | |
class: cls, | |
param: params, | |
mockedErr: &Error{Code: StatusNotFound}, | |
wantCode: StatusNotFound, | |
wantQueryInput: inputs, | |
}, | |
{ | |
name: "forbidden", | |
class: cls, | |
param: params, | |
authErr: errAny, | |
wantCode: StatusForbidden, | |
wantQueryInput: inputs, | |
}, | |
{ | |
name: "internal error", | |
class: cls, | |
param: params, | |
lockErr: errAny, | |
wantCode: StatusInternalServerError, | |
wantQueryInput: inputs, | |
}, | |
{ | |
name: "happy path", | |
class: cls, | |
param: params, | |
mockedDBResponse: []search.Result{ | |
{ | |
ClassName: cls, | |
Schema: map[string]interface{}{ | |
"foo": "bar", | |
}, | |
Dims: 3, | |
Dist: 0, | |
}, | |
}, | |
wantResponse: []*models.Object{{ | |
Class: cls, | |
VectorWeights: map[string]string(nil), | |
Properties: map[string]interface{}{ | |
"foo": "bar", | |
}, | |
}}, | |
wantQueryInput: inputs, | |
}, | |
{ | |
name: "happy path with explicit vector requested", | |
class: cls, | |
param: QueryParams{ | |
Class: cls, | |
Limit: ptInt64(10), | |
Additional: additional.Properties{Vector: true}, | |
}, | |
mockedDBResponse: []search.Result{ | |
{ | |
ClassName: cls, | |
Schema: map[string]interface{}{ | |
"foo": "bar", | |
}, | |
Dims: 3, | |
}, | |
}, | |
wantResponse: []*models.Object{{ | |
Class: cls, | |
VectorWeights: map[string]string(nil), | |
Properties: map[string]interface{}{ | |
"foo": "bar", | |
}, | |
}}, | |
wantQueryInput: QueryInput{ | |
Class: cls, | |
Limit: 10, | |
Additional: additional.Properties{Vector: true}, | |
}, | |
wantUsageTracking: true, | |
}, | |
{ | |
name: "bad request", | |
class: cls, | |
param: QueryParams{Class: cls, Offset: ptInt64(1), Limit: &m.config.Config.QueryMaximumResults}, | |
wantCode: StatusBadRequest, | |
wantQueryInput: inputs, | |
}, | |
} | |
for i, tc := range tests { | |
t.Run(tc.name, func(t *testing.T) { | |
m.authorizer.Err = tc.authErr | |
m.locks.Err = tc.lockErr | |
if tc.authErr == nil && tc.lockErr == nil { | |
m.repo.On("Query", &tc.wantQueryInput).Return(tc.mockedDBResponse, tc.mockedErr).Once() | |
} | |
if tc.wantUsageTracking { | |
m.metrics.On("AddUsageDimensions", cls, "get_rest", "list_include_vector", | |
tc.mockedDBResponse[0].Dims) | |
} | |
res, err := m.Manager.Query(context.Background(), nil, &tc.param) | |
code := 0 | |
if err != nil { | |
code = err.Code | |
} | |
if tc.wantCode != code { | |
t.Errorf("case %d expected:%v got:%v", i+1, tc.wantCode, code) | |
} | |
assert.Equal(t, tc.wantResponse, res) | |
}) | |
} | |
} | |