Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package clients | |
import ( | |
"context" | |
"encoding/json" | |
"net/http" | |
"net/http/httptest" | |
"testing" | |
"time" | |
"github.com/go-openapi/strfmt" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/storobj" | |
"github.com/weaviate/weaviate/usecases/objects" | |
"github.com/weaviate/weaviate/usecases/replica" | |
) | |
const ( | |
RequestError = "RIDNotFound" | |
RequestSuccess = "RIDSuccess" | |
RequestInternalError = "RIDInternal" | |
RequestMalFormedResponse = "RIDMalFormed" | |
) | |
const ( | |
UUID1 = strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168241") | |
UUID2 = strfmt.UUID("73f2eb5f-5abf-447a-81ca-74b1dd168242") | |
) | |
type fakeServer struct { | |
method string | |
path string | |
RequestError replica.SimpleResponse | |
RequestSuccess replica.SimpleResponse | |
host string | |
} | |
// newFakeReplicationServer builds a fakeServer that expects the given HTTP
// method and URL path.
//
// Its error payload holds a single replica.Error with message "error", and its
// success payload is an empty replica.SimpleResponse. The server is not started
// here; call server() for that. The t parameter is currently unused.
func newFakeReplicationServer(t *testing.T, method, path string) *fakeServer {
	fs := &fakeServer{method: method, path: path}
	fs.RequestError = replica.SimpleResponse{
		Errors: []replica.Error{{Msg: "error"}},
	}
	fs.RequestSuccess = replica.SimpleResponse{}
	return fs
}
// server starts an httptest.Server backed by f and stores its "host:port"
// address in f.host, so replication clients can target it. The caller must
// Close the returned server.
//
// The handler first rejects with 400 (and reports a test error) any request
// whose method or URL path differs from the configured ones. It then dispatches
// on the replica.RequestKey query parameter:
//   - RequestInternalError: status 500, no body
//   - RequestError: f.RequestError encoded as JSON
//   - RequestSuccess: f.RequestSuccess encoded as JSON
//   - RequestMalFormedResponse: a body that is not valid JSON
//
// Any other request ID is reported as a test error and answered with 400. This
// way a test that forgets to set the ID fails loudly instead of silently
// receiving an empty 200 response.
func (f *fakeServer) server(t *testing.T) *httptest.Server {
	// writeJSON encodes v into the response, reporting any failure to the test.
	writeJSON := func(w http.ResponseWriter, v interface{}) {
		b, err := json.Marshal(v)
		if err != nil {
			t.Errorf("marshal fake response: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(b); err != nil {
			t.Errorf("write fake response: %v", err)
		}
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != f.method {
			t.Errorf("method want %s got %s", f.method, r.Method)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if f.path != r.URL.Path {
			t.Errorf("path want %s got %s", f.path, r.URL.Path)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		switch requestID := r.URL.Query().Get(replica.RequestKey); requestID {
		case RequestInternalError:
			w.WriteHeader(http.StatusInternalServerError)
		case RequestError:
			writeJSON(w, &f.RequestError)
		case RequestSuccess:
			// Serve the configured success payload so the RequestSuccess field
			// is honoured rather than ignored.
			writeJSON(w, &f.RequestSuccess)
		case RequestMalFormedResponse:
			if _, err := w.Write([]byte(`mal formed`)); err != nil {
				t.Errorf("write fake response: %v", err)
			}
		default:
			t.Errorf("unexpected request id %q", requestID)
			w.WriteHeader(http.StatusBadRequest)
		}
	}

	serv := httptest.NewServer(http.HandlerFunc(handler))
	// The listener address is exactly the host:port that serv.URL carries
	// after its "http://" scheme, without hard-coding the prefix length.
	f.host = serv.Listener.Addr().String()
	return serv
}
func anyObject(uuid strfmt.UUID) models.Object { | |
return models.Object{ | |
Class: "C1", | |
CreationTimeUnix: 900000000001, | |
LastUpdateTimeUnix: 900000000002, | |
ID: uuid, | |
Properties: map[string]interface{}{ | |
"stringProp": "string", | |
"textProp": "text", | |
"datePropArray": []string{"1980-01-01T00:00:00+02:00"}, | |
}, | |
} | |
} | |
// TestReplicationPutObject exercises replicationClient.PutObject against a
// fakeServer that expects POST /replicas/indices/C1/shards/S1/objects.
//
// Subtests run in this order:
//   - an object that cannot be encoded
//   - an unreachable host
//   - the three canned server behaviours: error payload, malformed body, 500
func TestReplicationPutObject(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	srv := newFakeReplicationServer(t, http.MethodPost, "/replicas/indices/C1/shards/S1/objects")
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())

	t.Run("EncodeRequest", func(t *testing.T) {
		// A zero-value object is expected to be rejected with an encode error.
		_, err := rc.PutObject(ctx, "Node1", "C1", "S1", "RID", &storobj.Object{})
		assert.NotNil(t, err)
		assert.Contains(t, err.Error(), "encode")
	})

	obj := &storobj.Object{MarshallerVersion: 1, Object: anyObject(UUID1)}
	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.PutObject(ctx, tc.host, "C1", "S1", tc.rid, obj)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationDeleteObject exercises replicationClient.DeleteObject against
// a fakeServer that expects DELETE /replicas/indices/C1/shards/S1/objects/<UUID1>.
//
// It covers an unreachable host and the three canned server behaviours: error
// payload, malformed body, and 500.
func TestReplicationDeleteObject(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	id := UUID1
	srv := newFakeReplicationServer(t, http.MethodDelete, "/replicas/indices/C1/shards/S1/objects/"+id.String())
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())

	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.DeleteObject(ctx, tc.host, "C1", "S1", tc.rid, id)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationPutObjects exercises replicationClient.PutObjects against a
// fakeServer that expects POST /replicas/indices/C1/shards/S1/objects.
//
// The fake's error payload carries two entries, so the Error case checks that
// the full list is returned. Subtests run in this order:
//   - a batch that cannot be encoded
//   - an unreachable host
//   - the three canned server behaviours
func TestReplicationPutObjects(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	srv := newFakeReplicationServer(t, http.MethodPost, "/replicas/indices/C1/shards/S1/objects")
	srv.RequestError.Errors = append(srv.RequestError.Errors, replica.Error{Msg: "error2"})
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())

	t.Run("EncodeRequest", func(t *testing.T) {
		// A batch holding a zero-value object is expected to yield an encode error.
		_, err := rc.PutObjects(ctx, "Node1", "C1", "S1", "RID", []*storobj.Object{{}})
		assert.NotNil(t, err)
		assert.Contains(t, err.Error(), "encode")
	})

	batch := []*storobj.Object{
		{MarshallerVersion: 1, Object: anyObject(UUID1)},
		{MarshallerVersion: 1, Object: anyObject(UUID2)},
	}
	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.PutObjects(ctx, tc.host, "C1", "S1", tc.rid, batch)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationMergeObject exercises replicationClient.MergeObject against a
// fakeServer that expects PATCH /replicas/indices/C1/shards/S1/objects/<UUID1>.
//
// It covers an unreachable host and the three canned server behaviours.
func TestReplicationMergeObject(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	id := UUID1
	srv := newFakeReplicationServer(t, http.MethodPatch, "/replicas/indices/C1/shards/S1/objects/"+id.String())
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())
	doc := &objects.MergeDocument{ID: id}

	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.MergeObject(ctx, tc.host, "C1", "S1", tc.rid, doc)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationAddReferences exercises replicationClient.AddReferences
// against a fakeServer that expects
// POST /replicas/indices/C1/shards/S1/objects/references.
//
// The fake's error payload carries two entries. The test covers an unreachable
// host and the three canned server behaviours.
func TestReplicationAddReferences(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	srv := newFakeReplicationServer(t, http.MethodPost, "/replicas/indices/C1/shards/S1/objects/references")
	srv.RequestError.Errors = append(srv.RequestError.Errors, replica.Error{Msg: "error2"})
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())
	refs := []objects.BatchReference{{OriginalIndex: 1}, {OriginalIndex: 2}}

	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.AddReferences(ctx, tc.host, "C1", "S1", tc.rid, refs)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationDeleteObjects exercises replicationClient.DeleteObjects
// against a fakeServer that expects
// DELETE /replicas/indices/C1/shards/S1/objects.
//
// The fake's error payload carries two entries. Every call passes false as
// the final boolean argument. The test covers an unreachable host and the
// three canned server behaviours.
func TestReplicationDeleteObjects(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	srv := newFakeReplicationServer(t, http.MethodDelete, "/replicas/indices/C1/shards/S1/objects")
	srv.RequestError.Errors = append(srv.RequestError.Errors, replica.Error{Msg: "error2"})
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())
	ids := []strfmt.UUID{strfmt.UUID("1"), strfmt.UUID("2")}

	cases := []struct {
		name    string
		host    string
		rid     string
		wantErr string // "" means: no error, and the response echoes srv.RequestError
	}{
		{"ConnectionError", "", "", "connect"},
		{"Error", srv.host, RequestError, ""},
		{"DecodeResponse", srv.host, RequestMalFormedResponse, "decode response"},
		{"ServerInternalError", srv.host, RequestInternalError, "status code"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := rc.DeleteObjects(ctx, tc.host, "C1", "S1", tc.rid, ids, false)
			if tc.wantErr == "" {
				assert.Nil(t, err)
				assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
				return
			}
			assert.NotNil(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestReplicationAbort exercises replicationClient.Abort against a fakeServer
// that expects POST /replicas/indices/C1/shards/S1:abort.
//
// The subtests adjust the client's back-off and timeout knobs at specific
// points. Their exact effect depends on Abort's retry logic, which lives
// outside this file.
// NOTE(review): confirm why ConnectionError needs a larger maxBackOff, and
// why timeoutUnit is raised only before ServerInternalError.
func TestReplicationAbort(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	srv := newFakeReplicationServer(t, http.MethodPost, "/replicas/indices/C1/shards/S1:abort")
	ts := srv.server(t)
	defer ts.Close()
	rc := newReplicationClient(ts.Client())

	t.Run("ConnectionError", func(t *testing.T) {
		// A dedicated client keeps this back-off change out of the later subtests.
		local := newReplicationClient(ts.Client())
		local.maxBackOff = local.timeoutUnit * 20
		_, err := local.Abort(ctx, "", "C1", "S1", "")
		assert.NotNil(t, err)
		assert.Contains(t, err.Error(), "connect")
	})

	t.Run("Error", func(t *testing.T) {
		resp, err := rc.Abort(ctx, srv.host, "C1", "S1", RequestError)
		assert.Nil(t, err)
		assert.Equal(t, replica.SimpleResponse{Errors: srv.RequestError.Errors}, resp)
	})

	t.Run("DecodeResponse", func(t *testing.T) {
		_, err := rc.Abort(ctx, srv.host, "C1", "S1", RequestMalFormedResponse)
		assert.NotNil(t, err)
		assert.Contains(t, err.Error(), "decode response")
	})

	// Widen the shared client's timeout unit before the 500 case.
	rc.timeoutUnit = rc.maxBackOff * 3

	t.Run("ServerInternalError", func(t *testing.T) {
		_, err := rc.Abort(ctx, srv.host, "C1", "S1", RequestInternalError)
		assert.NotNil(t, err)
		assert.Contains(t, err.Error(), "status code")
	})
}
func TestReplicationCommit(t *testing.T) { | |
t.Parallel() | |
ctx := context.Background() | |
path := "/replicas/indices/C1/shards/S1:commit" | |
fs := newFakeReplicationServer(t, http.MethodPost, path) | |
ts := fs.server(t) | |
defer ts.Close() | |
resp := replica.SimpleResponse{} | |
client := newReplicationClient(ts.Client()) | |
t.Run("ConnectionError", func(t *testing.T) { | |
err := client.Commit(ctx, "", "C1", "S1", "", &resp) | |
assert.NotNil(t, err) | |
assert.Contains(t, err.Error(), "connect") | |
}) | |
t.Run("Error", func(t *testing.T) { | |
err := client.Commit(ctx, fs.host, "C1", "S1", RequestError, &resp) | |
assert.Nil(t, err) | |
assert.Equal(t, replica.SimpleResponse{Errors: fs.RequestError.Errors}, resp) | |
}) | |
t.Run("DecodeResponse", func(t *testing.T) { | |
err := client.Commit(ctx, fs.host, "C1", "S1", RequestMalFormedResponse, &resp) | |
assert.NotNil(t, err) | |
assert.Contains(t, err.Error(), "decode response") | |
}) | |
t.Run("ServerInternalError", func(t *testing.T) { | |
err := client.Commit(ctx, fs.host, "C1", "S1", RequestInternalError, &resp) | |
assert.NotNil(t, err) | |
assert.Contains(t, err.Error(), "status code") | |
}) | |
} | |
// TestReplicationFetchObject checks that FetchObject decodes a single
// objects.Replica that the server returns in its MarshalBinary encoding.
//
// The stub server is closed when the test returns. Encoding or write failures
// inside the handler are reported as test errors instead of being dropped.
func TestReplicationFetchObject(t *testing.T) {
	t.Parallel()
	expected := objects.Replica{
		ID: UUID1,
		Object: &storobj.Object{
			MarshallerVersion: 1,
			Object: models.Object{
				ID: UUID1,
				Properties: map[string]interface{}{
					"stringProp": "abc",
				},
			},
			Vector:    []float32{1, 2, 3, 4, 5},
			VectorLen: 5,
		},
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := expected.MarshalBinary()
		if err != nil {
			t.Errorf("marshal replica: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(b); err != nil {
			t.Errorf("write replica: %v", err)
		}
	}))
	// Release the listener and wait for in-flight handlers when the test ends.
	defer server.Close()

	c := newReplicationClient(server.Client())
	// Listener.Addr is the host:port part of server.URL.
	resp, err := c.FetchObject(context.Background(), server.Listener.Addr().String(),
		"C1", "S1", expected.ID, nil, additional.Properties{})
	require.Nil(t, err)
	assert.Equal(t, expected.ID, resp.ID)
	assert.Equal(t, expected.Deleted, resp.Deleted)
	assert.EqualValues(t, expected.Object, resp.Object)
}
// TestReplicationFetchObjects checks that FetchObjects decodes a list of
// objects.Replicas that the server returns in its MarshalBinary encoding.
//
// Both expected IDs are requested, so the request matches the two replicas
// the test asserts on; the stub server ignores the request body anyway. The
// server is closed when the test returns, and handler failures are reported
// as test errors.
func TestReplicationFetchObjects(t *testing.T) {
	t.Parallel()
	expected := objects.Replicas{
		{
			ID: UUID1,
			Object: &storobj.Object{
				MarshallerVersion: 1,
				Object: models.Object{
					ID: UUID1,
					Properties: map[string]interface{}{
						"stringProp": "abc",
					},
				},
				Vector:    []float32{1, 2, 3, 4, 5},
				VectorLen: 5,
			},
		},
		{
			ID: UUID2,
			Object: &storobj.Object{
				MarshallerVersion: 1,
				Object: models.Object{
					ID: UUID2,
					Properties: map[string]interface{}{
						"floatProp": float64(123),
					},
				},
				Vector:    []float32{10, 20, 30, 40, 50},
				VectorLen: 5,
			},
		},
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := expected.MarshalBinary()
		if err != nil {
			t.Errorf("marshal replicas: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(b); err != nil {
			t.Errorf("write replicas: %v", err)
		}
	}))
	defer server.Close()

	c := newReplicationClient(server.Client())
	ids := []strfmt.UUID{expected[0].ID, expected[1].ID}
	resp, err := c.FetchObjects(context.Background(), server.Listener.Addr().String(), "C1", "S1", ids)
	require.Nil(t, err)
	require.Len(t, resp, 2)
	for i := range expected {
		assert.Equal(t, expected[i].ID, resp[i].ID)
		assert.Equal(t, expected[i].Deleted, resp[i].Deleted)
		assert.EqualValues(t, expected[i].Object, resp[i].Object)
	}
}
// TestReplicationDigestObjects checks that DigestObjects decodes the JSON list
// of replica.RepairResponse digests returned by the server, in order.
//
// The stub server is closed when the test returns, and handler failures are
// reported as test errors.
func TestReplicationDigestObjects(t *testing.T) {
	t.Parallel()
	now := time.Now()
	expected := []replica.RepairResponse{
		{
			ID:         UUID1.String(),
			UpdateTime: now.UnixMilli(),
			Version:    1,
		},
		{
			ID:         UUID2.String(),
			UpdateTime: now.UnixMilli(),
			Version:    1,
		},
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := json.Marshal(expected)
		if err != nil {
			t.Errorf("marshal digests: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(b); err != nil {
			t.Errorf("write digests: %v", err)
		}
	}))
	defer server.Close()

	c := newReplicationClient(server.Client())
	resp, err := c.DigestObjects(context.Background(), server.Listener.Addr().String(), "C1", "S1", []strfmt.UUID{
		strfmt.UUID(expected[0].ID),
		strfmt.UUID(expected[1].ID),
	})
	require.Nil(t, err)
	require.Len(t, resp, 2)
	for i := range expected {
		assert.Equal(t, expected[i].ID, resp[i].ID)
		assert.Equal(t, expected[i].Deleted, resp[i].Deleted)
		assert.Equal(t, expected[i].UpdateTime, resp[i].UpdateTime)
		assert.Equal(t, expected[i].Version, resp[i].Version)
	}
}
// TestReplicationOverwriteObjects checks that OverwriteObjects decodes the JSON
// list of replica.RepairResponse values returned by the server.
//
// The stub server replies with a fixed response regardless of the input. It
// is closed when the test returns, and handler failures are reported as test
// errors.
func TestReplicationOverwriteObjects(t *testing.T) {
	t.Parallel()
	now := time.Now()
	input := []*objects.VObject{
		{
			LatestObject: &models.Object{
				ID:                 UUID1,
				Class:              "C1",
				CreationTimeUnix:   now.UnixMilli(),
				LastUpdateTimeUnix: now.Add(time.Hour).UnixMilli(),
				Properties: map[string]interface{}{
					"stringProp": "abc",
				},
				Vector: []float32{1, 2, 3, 4, 5},
			},
			StaleUpdateTime: now.UnixMilli(),
			Version:         0,
		},
	}
	expected := []replica.RepairResponse{
		{
			ID:         UUID1.String(),
			Version:    1,
			UpdateTime: now.Add(time.Hour).UnixMilli(),
		},
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := json.Marshal(expected)
		if err != nil {
			t.Errorf("marshal repair responses: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(b); err != nil {
			t.Errorf("write repair responses: %v", err)
		}
	}))
	defer server.Close()

	c := newReplicationClient(server.Client())
	resp, err := c.OverwriteObjects(context.Background(), server.Listener.Addr().String(), "C1", "S1", input)
	require.Nil(t, err)
	require.Len(t, resp, 1)
	assert.Equal(t, expected[0].ID, resp[0].ID)
	assert.Equal(t, expected[0].Version, resp[0].Version)
	assert.Equal(t, expected[0].UpdateTime, resp[0].UpdateTime)
}
func TestExpBackOff(t *testing.T) { | |
N := 200 | |
av := time.Duration(0) | |
delay := time.Nanosecond * 20 | |
for i := 0; i < N; i++ { | |
av += backOff(delay) | |
} | |
av /= time.Duration(N) | |
if av < time.Nanosecond*30 || av > time.Nanosecond*50 { | |
t.Errorf("average time got %v", av) | |
} | |
} | |
// newReplicationClient builds a client via NewReplicationClient and asserts
// that its concrete type is *replicationClient; the assertion panics
// otherwise. It then overrides the timing knobs with small values: 1ms
// minBackOff, 8ms maxBackOff and 20ms timeoutUnit.
// NOTE(review): presumably this keeps retry-heavy tests fast; confirm.
func newReplicationClient(httpClient *http.Client) *replicationClient {
	rc := NewReplicationClient(httpClient).(*replicationClient)
	rc.minBackOff, rc.maxBackOff, rc.timeoutUnit = time.Millisecond, 8*time.Millisecond, 20*time.Millisecond
	return rc
}