SemanticSearchPOC / usecases /cluster /transactions_broadcast_test.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
5.58 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package cluster
import (
"context"
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBroadcastOpenTransaction verifies that broadcasting a transaction calls
// OpenTransaction on the client exactly once for each of the three hosts held
// by the fake cluster state.
func TestBroadcastOpenTransaction(t *testing.T) {
	client := &fakeClient{}
	state := &fakeState{hosts: []string{"host1", "host2", "host3"}}

	bc := NewTxBroadcaster(state, client)

	tx := &Transaction{ID: "foo"}

	// NoError compares the error interface against nil directly, which is the
	// same check production callers perform (err != nil). require.Nil relies
	// on reflection and would also accept a typed-nil pointer wrapped in an
	// error interface.
	err := bc.BroadcastTransaction(context.Background(), tx)
	require.NoError(t, err)

	assert.ElementsMatch(t, []string{"host1", "host2", "host3"}, client.openCalled)
}
// TestBroadcastOpenTransactionWithReturnPayload verifies that the payloads the
// fake client sets while opening the transaction are passed to the configured
// consensus function, and that the payload returned by that function can be
// read from the transaction that was broadcast.
func TestBroadcastOpenTransactionWithReturnPayload(t *testing.T) {
	client := &fakeClient{}
	state := &fakeState{hosts: []string{"host1", "host2", "host3"}}

	bc := NewTxBroadcaster(state, client)
	bc.SetConsensusFunction(func(ctx context.Context,
		in []*Transaction,
	) (*Transaction, error) {
		// instead of actually reaching a consensus this test mock simply merges
		// all the individual results. For testing purposes this is even better
		// because now we can be sure that every element was considered.
		payloads := make([]string, 0, len(in))
		for _, tx := range in {
			// a checked assertion turns an unexpected payload type into a
			// regular error instead of a panic inside the broadcaster
			payload, ok := tx.Payload.(string)
			if !ok {
				return nil, fmt.Errorf("unexpected payload type %T", tx.Payload)
			}
			payloads = append(payloads, payload)
		}

		return &Transaction{
			Payload: strings.Join(payloads, ","),
		}, nil
	})

	tx := &Transaction{ID: "foo"}

	// NoError performs the plain err != nil comparison; require.Nil would also
	// accept a typed-nil pointer wrapped in an error interface.
	err := bc.BroadcastTransaction(context.Background(), tx)
	require.NoError(t, err)

	assert.ElementsMatch(t, []string{"host1", "host2", "host3"}, client.openCalled)

	// fail with a readable message rather than panic if the payload is not the
	// merged string produced by the consensus function above
	merged, ok := tx.Payload.(string)
	require.True(t, ok, "expected string payload, got %T", tx.Payload)

	results := strings.Split(merged, ",")
	assert.ElementsMatch(t, []string{
		"hello_from_host1",
		"hello_from_host2",
		"hello_from_host3",
	}, results)
}
// TestBroadcastOpenTransactionAfterNodeHasDied expects opening a transaction
// to fail when a host is dropped from the cluster state after the
// broadcaster's ideal state listed three members. The error must mention the
// dropped host and no host may have received an open call.
func TestBroadcastOpenTransactionAfterNodeHasDied(t *testing.T) {
	var (
		allHosts  = []string{"host1", "host2", "host3"}
		liveHosts = []string{"host1", "host3"} // host2 is dead
	)

	fc := &fakeClient{}
	fs := &fakeState{hosts: allHosts}
	broadcaster := NewTxBroadcaster(fs, fc)

	// only remove a host once the ideal state lists all three members
	waitUntilIdealStateHasReached(t, broadcaster, 3, 4*time.Second)

	fs.updateHosts(liveHosts)

	err := broadcaster.BroadcastTransaction(context.Background(), &Transaction{ID: "foo"})
	require.NotNil(t, err)
	assert.Contains(t, err.Error(), "host2")

	// no node should have received an open
	assert.ElementsMatch(t, []string{}, fc.openCalled)
}
// waitUntilIdealStateHasReached polls the ideal state of bc until it reports
// exactly goal members. The state is checked right away and then again on
// every 250ms tick. If max elapses first, the test is marked as failed through
// t.Error and the helper returns; the calling test keeps running.
func waitUntilIdealStateHasReached(t *testing.T, bc *TxBroadcaster, goal int,
	max time.Duration,
) {
	// attribute a timeout failure to the calling test instead of this helper
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), max)
	defer cancel()

	interval := time.NewTicker(250 * time.Millisecond)
	defer interval.Stop()

	for {
		// check before waiting so that an already reached goal does not cost a
		// full tick
		if len(bc.ideal.Members()) == goal {
			return
		}

		select {
		case <-ctx.Done():
			t.Error(fmt.Errorf("waiting to reach state goal %d: %w", goal, ctx.Err()))
			return
		case <-interval.C:
			// next tick: loop around and check the state again
		}
	}
}
// TestBroadcastAbortTransaction verifies that broadcasting an abort calls
// AbortTransaction on the client exactly once for each of the three hosts held
// by the fake cluster state.
func TestBroadcastAbortTransaction(t *testing.T) {
	client := &fakeClient{}
	state := &fakeState{hosts: []string{"host1", "host2", "host3"}}

	bc := NewTxBroadcaster(state, client)

	tx := &Transaction{ID: "foo"}

	// NoError performs the plain err != nil comparison; require.Nil relies on
	// reflection and would also accept a typed-nil pointer wrapped in an error
	// interface.
	err := bc.BroadcastAbortTransaction(context.Background(), tx)
	require.NoError(t, err)

	assert.ElementsMatch(t, []string{"host1", "host2", "host3"}, client.abortCalled)
}
// TestBroadcastCommitTransaction verifies that broadcasting a commit calls
// CommitTransaction on the client exactly once for each of the three hosts
// held by the fake cluster state.
func TestBroadcastCommitTransaction(t *testing.T) {
	client := &fakeClient{}
	state := &fakeState{hosts: []string{"host1", "host2", "host3"}}

	bc := NewTxBroadcaster(state, client)

	tx := &Transaction{ID: "foo"}

	// NoError performs the plain err != nil comparison; require.Nil relies on
	// reflection and would also accept a typed-nil pointer wrapped in an error
	// interface.
	err := bc.BroadcastCommitTransaction(context.Background(), tx)
	require.NoError(t, err)

	assert.ElementsMatch(t, []string{"host1", "host2", "host3"}, client.commitCalled)
}
// TestBroadcastCommitTransactionAfterNodeHasDied expects committing a
// transaction to fail when a host is dropped from the cluster state after the
// broadcaster's ideal state listed three members. The error must mention the
// dropped host and no host may have received a commit call.
func TestBroadcastCommitTransactionAfterNodeHasDied(t *testing.T) {
	var (
		allHosts  = []string{"host1", "host2", "host3"}
		liveHosts = []string{"host1", "host3"} // host2 is gone
	)

	fc := &fakeClient{}
	fs := &fakeState{hosts: allHosts}
	broadcaster := NewTxBroadcaster(fs, fc)

	// only remove a host once the ideal state lists all three members
	waitUntilIdealStateHasReached(t, broadcaster, 3, 4*time.Second)

	fs.updateHosts(liveHosts)

	err := broadcaster.BroadcastCommitTransaction(context.Background(), &Transaction{ID: "foo"})
	require.NotNil(t, err)
	assert.Contains(t, err.Error(), "host2")

	// no node should have received the commit
	assert.ElementsMatch(t, []string{}, fc.commitCalled)
}
// fakeState is the test double passed to NewTxBroadcaster as cluster state.
// It serves a host list that tests can swap out at runtime; the embedded
// mutex guards every access to hosts.
type fakeState struct {
	hosts []string
	sync.Mutex
}

// updateHosts replaces the host list with newHosts.
func (f *fakeState) updateHosts(newHosts []string) {
	f.Lock()
	f.hosts = newHosts
	f.Unlock()
}

// Hostnames returns the current host list.
func (f *fakeState) Hostnames() []string {
	f.Lock()
	current := f.hosts
	f.Unlock()

	return current
}

// AllNames returns the same host list as Hostnames.
func (f *fakeState) AllNames() []string {
	return f.Hostnames()
}
// fakeClient is the test double passed to NewTxBroadcaster as client. For
// each kind of call it records the hosts it was invoked with; the embedded
// mutex guards the recorded slices.
type fakeClient struct {
	sync.Mutex
	openCalled   []string
	abortCalled  []string
	commitCalled []string
}

// locked runs fn while holding the client's mutex and releases it afterwards,
// even if fn panics.
func (f *fakeClient) locked(fn func()) {
	f.Lock()
	defer f.Unlock()
	fn()
}

// OpenTransaction records host as opened and sets the payload of tx to
// "hello_from_" followed by the host name. It always returns nil.
func (f *fakeClient) OpenTransaction(ctx context.Context, host string, tx *Transaction) error {
	f.locked(func() {
		f.openCalled = append(f.openCalled, host)
		tx.Payload = "hello_from_" + host
	})

	return nil
}

// AbortTransaction records host as aborted. It always returns nil.
func (f *fakeClient) AbortTransaction(ctx context.Context, host string, tx *Transaction) error {
	f.locked(func() {
		f.abortCalled = append(f.abortCalled, host)
	})

	return nil
}

// CommitTransaction records host as committed. It always returns nil.
func (f *fakeClient) CommitTransaction(ctx context.Context, host string, tx *Transaction) error {
	f.locked(func() {
		f.commitCalled = append(f.commitCalled, host)
	})

	return nil
}