From ce038b04dfceb2203e6275d8ec09a2b41c5f1e9e Mon Sep 17 00:00:00 2001
From: Alexander Chuprov
Date: Mon, 8 Apr 2024 20:06:55 +0300
Subject: [PATCH] [#30] pool: Add pool Update

Signed-off-by: Alexander Chuprov
---
 pool/pool.go      | 128 +++++++++++++++++++++++++++++++----
 pool/pool_test.go | 165 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 279 insertions(+), 14 deletions(-)

diff --git a/pool/pool.go b/pool/pool.go
index 30c5219..9803aa6 100644
--- a/pool/pool.go
+++ b/pool/pool.go
@@ -1871,9 +1871,19 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error {
 // See also InitParameters.SetClientRebalanceInterval.
 func (p *Pool) Dial(ctx context.Context) error {
 	pool := p.pool.Load()
-	err := pool.Dial(ctx)
+	if err := pool.Dial(ctx); err != nil {
+		return err
+	}
+
+	pool.cancelLock.Lock()
+	if pool.cancel != nil {
+		pool.cancel()
+	}
+	pool.cancelLock.Unlock()
+
 	pool.startRebalance(ctx)
-	return err
+
+	return nil
 }
 
 // FindSiblingByParentID implements relations.Relations.
@@ -2004,6 +2014,27 @@ func (p *Pool) Statistic() Statistic {
 	return p.pool.Load().Statistic()
 }
 
+// Update refreshes the list of nodes without recreating the pool.
+// Use a long-lived context, because cancelling it stops rebalancing early.
+// Update may interrupt operations in progress on nodes removed from the list.
+// It guarantees that:
+// 1) connections preserved in the new list are not closed;
+// 2) if an error occurs, the pool remains operational.
+func (p *Pool) Update(ctx context.Context, prm []NodeParam) error {
+	pool := p.pool.Load()
+
+	newPool, equal, err := pool.update(ctx, prm)
+	if equal || err != nil {
+		return err
+	}
+
+	newPool.startRebalance(ctx)
+	oldPool := p.pool.Swap(newPool)
+	oldPool.stopRebalance()
+
+	return nil
+}
+
 type pool struct {
 	innerPools []*innerPool
 	key        *ecdsa.PrivateKey
@@ -2073,8 +2104,6 @@ func newPool(options InitParameters) (*pool, error) {
 }
 
 func (p *pool) Dial(ctx context.Context) error {
-	p.stopRebalance()
-
 	err := p.dial(ctx, nil)
 	if err != nil {
 		return err
 	}
@@ -2142,6 +2171,64 @@ func (p *pool) dial(ctx context.Context, existingClients map[string]client) error {
 	return nil
 }
 
+func nodesParamEqual(a, b []*nodesParam) bool {
+	if len(a) != len(b) {
+		return false
+	}
+
+	for i, v := range a {
+		if v.priority != b[i].priority || len(v.addresses) != len(b[i].addresses) {
+			return false
+		}
+
+		for j, address := range v.addresses {
+			if address != b[i].addresses[j] {
+				return false
+			}
+		}
+	}
+
+	return true
+}
+
+// update must not be called concurrently with other operations on this pool instance.
+func (p *pool) update(ctx context.Context, prm []NodeParam) (*pool, bool, error) {
+	newPool := *p
+
+	existingClients := make(map[string]client)
+	for i, pool := range newPool.rebalanceParams.nodesParams {
+		for j := range pool.weights {
+			existingClients[pool.addresses[j]] = newPool.innerPools[i].clients[j]
+		}
+	}
+
+	nodesParams, err := adjustNodeParams(prm)
+	if err != nil {
+		return nil, false, err
+	}
+
+	if nodesParamEqual(newPool.rebalanceParams.nodesParams, nodesParams) {
+		return nil, true, nil
+	}
+
+	newPool.rebalanceParams.nodesParams = nodesParams
+
+	err = newPool.dial(ctx, existingClients)
+	if err != nil {
+		return nil, false, err
+	}
+
+	// After newPool.dial(ctx, existingClients), existingClients contains only the outdated clients.
+	// Close them so their connections do not leak.
+	for _, client := range existingClients {
+		if clientErr := client.close(); clientErr != nil {
+			err = errors.Join(err, clientErr)
+		}
+	}
+
+	return &newPool, false, err
+}
+
 func (p *pool) log(level zapcore.Level, msg string, fields ...zap.Field) {
 	if p.logger == nil {
 		return
@@ -2195,6 +2282,11 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
 	}
 }
 
+type addressWeightPair struct {
+	address string
+	weight  float64
+}
+
 func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) {
 	if len(nodeParams) == 0 {
 		return nil, errors.New("no FrostFS peers configured")
@@ -2221,6 +2313,22 @@ func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) {
 		return nodesParams[i].priority < nodesParams[j].priority
 	})
 
+	for _, nodes := range nodesParams {
+		addressWeightPairs := make([]addressWeightPair, len(nodes.addresses))
+		for i := range nodes.addresses {
+			addressWeightPairs[i] = addressWeightPair{address: nodes.addresses[i], weight: nodes.weights[i]}
+		}
+
+		sort.Slice(addressWeightPairs, func(i, j int) bool {
+			return addressWeightPairs[i].address < addressWeightPairs[j].address
+		})
+
+		for i, pair := range addressWeightPairs {
+			nodes.addresses[i] = pair.address
+			nodes.weights[i] = pair.weight
+		}
+	}
+
 	return nodesParams, nil
 }
 
@@ -2228,12 +2336,6 @@ func (p *pool) startRebalance(ctx context.Context) {
 	p.cancelLock.Lock()
 	defer p.cancelLock.Unlock()
 
-	// stop rebalance
-	if p.cancel != nil {
-		p.cancel()
-		<-p.closedCh
-	}
-
 	rebalanceCtx, cancel := context.WithCancel(ctx)
 	p.closedCh = make(chan struct{})
 
@@ -2266,10 +2368,8 @@ func (p *pool) stopRebalance() {
 	p.cancelLock.Lock()
 	defer p.cancelLock.Unlock()
 
-	if p.cancel != nil {
-		p.cancel()
-		<-p.closedCh
-	}
+	p.cancel()
+	<-p.closedCh
 }
 
 func (p *pool) updateNodesHealth(ctx context.Context, buffers [][]float64) {
diff --git a/pool/pool_test.go b/pool/pool_test.go
index 03b1b2a..0ff1a8c 100644
--- a/pool/pool_test.go
+++ b/pool/pool_test.go
@@ -4,7 +4,9 @@ import (
 	"context"
 	"crypto/ecdsa"
 	"errors"
+	"reflect"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
@@ -174,6 +176,150 @@ func TestTwoNodes(t *testing.T) {
 	require.True(t, assertAuthKeyForAny(st, clientKeys))
 }
 
+func TestTwoNodesUpdate(t *testing.T) {
+	var clientKeys []*ecdsa.PrivateKey
+	mockClientBuilder := func(addr string) client {
+		key := newPrivateKey(t)
+		clientKeys = append(clientKeys, key)
+		return newMockClient(addr, *key)
+	}
+
+	opts := InitParameters{
+		key: newPrivateKey(t),
+		nodeParams: []NodeParam{
+			{2, "peer0", 1},
+			{2, "peer1", 1},
+		},
+	}
+	opts.setClientBuilder(mockClientBuilder)
+
+	pool, err := NewPool(opts)
+	require.NoError(t, err)
+	err = pool.Dial(context.Background())
+	require.NoError(t, err)
+	t.Cleanup(pool.Close)
+
+	cp, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+	st, _ := pool.pool.Load().cache.Get(formCacheKey(cp.address(), pool.pool.Load().key, false))
+	require.True(t, assertAuthKeyForAny(st, clientKeys))
+
+	require.NoError(t, pool.Update(context.Background(), []NodeParam{
+		{1, "peer-1", 1},
+		{2, "peer0", 1},
+		{2, "peer1", 1},
+	}))
+
+	st1, _ := pool.pool.Load().cache.Get(formCacheKey(cp.address(), pool.pool.Load().key, false))
+	require.Equal(t, st, st1)
+
+	cp2, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+	require.Equal(t, "peer-1", cp2.address())
+}
+
+func TestUpdateNodeMultithread(t *testing.T) {
+	key1 := newPrivateKey(t)
+	mockClientBuilder := func(addr string) client {
+		return newMockClient(addr, *key1)
+	}
+
+	opts := InitParameters{
+		key:        newPrivateKey(t),
+		nodeParams: []NodeParam{{1, "peer0", 1}},
+	}
+	opts.setClientBuilder(mockClientBuilder)
+
+	pool, err := NewPool(opts)
+	require.NoError(t, err)
+	err = pool.Dial(context.Background())
+	require.NoError(t, err)
+	t.Cleanup(pool.Close)
+
+	wg := sync.WaitGroup{}
+
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			err := pool.Update(context.Background(), []NodeParam{{1, "peer" + strconv.Itoa(i+1), 1}})
+			require.NoError(t, err)
+		}(i)
+	}
+
+	wg.Wait()
+}
+
+func TestUpdateNodeEqualConfig(t *testing.T) {
+	key1 := newPrivateKey(t)
+	mockClientBuilder := func(addr string) client {
+		return newMockClient(addr, *key1)
+	}
+
+	opts := InitParameters{
+		key:        newPrivateKey(t),
+		nodeParams: []NodeParam{{1, "peer0", 1}},
+	}
+	opts.setClientBuilder(mockClientBuilder)
+
+	pool, err := NewPool(opts)
+
+	require.NoError(t, err)
+	err = pool.Dial(context.Background())
+	require.NoError(t, err)
+	t.Cleanup(pool.Close)
+
+	cp, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+	st, _ := pool.pool.Load().cache.Get(formCacheKey(cp.address(), pool.pool.Load().key, false))
+	expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey)
+	require.True(t, st.AssertAuthKey(&expectedAuthKey))
+
+	_, flag, err := pool.pool.Load().update(context.Background(), []NodeParam{{1, "peer0", 1}})
+	require.NoError(t, err)
+	require.True(t, flag)
+}
+
+func TestUpdateNode(t *testing.T) {
+	key1 := newPrivateKey(t)
+	mockClientBuilder := func(addr string) client {
+		return newMockClient(addr, *key1)
+	}
+
+	opts := InitParameters{
+		key:        newPrivateKey(t),
+		nodeParams: []NodeParam{{1, "peer0", 1}},
+	}
+	opts.setClientBuilder(mockClientBuilder)
+
+	pool, err := NewPool(opts)
+	require.NoError(t, err)
+	err = pool.Dial(context.Background())
+	require.NoError(t, err)
+	t.Cleanup(pool.Close)
+
+	cp, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+	st, _ := pool.pool.Load().cache.Get(formCacheKey(cp.address(), pool.pool.Load().key, false))
+	expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey)
+	require.True(t, st.AssertAuthKey(&expectedAuthKey))
+
+	require.NoError(t, pool.Update(context.Background(), []NodeParam{{1, "peer0", 1}}))
+	cp1, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+	st1, _ := pool.pool.Load().cache.Get(formCacheKey(cp1.address(), pool.pool.Load().key, false))
+	require.Equal(t, st, st1)
+	require.Equal(t, cp, cp1)
+
+	require.NoError(t, pool.Update(context.Background(), []NodeParam{{1, "peer1", 1}}))
+	cp2, err := pool.pool.Load().connection()
+	require.NoError(t, err)
+
+	st2, _ := pool.pool.Load().cache.Get(formCacheKey(cp2.address(), pool.pool.Load().key, false))
+	require.NotEqual(t, cp.address(), cp2.address())
+	require.NotEqual(t, st, st2)
+}
+
 func assertAuthKeyForAny(st session.Object, clientKeys []*ecdsa.PrivateKey) bool {
 	for _, key := range clientKeys {
 		expectedAuthKey := frostfsecdsa.PublicKey(key.PublicKey)
@@ -668,6 +814,25 @@ func TestHandleError(t *testing.T) {
 	}
 }
 
+func TestAdjustNodeParams(t *testing.T) {
+	nodes1 := []NodeParam{
+		{1, "peer0", 1},
+		{1, "peer1", 2},
+		{1, "peer2", 3},
+		{2, "peer21", 2},
+	}
+	nodes2 := []NodeParam{
+		{1, "peer0", 1},
+		{1, "peer2", 3},
+		{1, "peer1", 2},
+		{2, "peer21", 2},
+	}
+
+	nodesParam1, _ := adjustNodeParams(nodes1)
+	nodesParam2, _ := adjustNodeParams(nodes2)
+	require.True(t, reflect.DeepEqual(nodesParam1, nodesParam2))
+}
+
 func TestSwitchAfterErrorThreshold(t *testing.T) {
 	nodes := []NodeParam{
 		{1, "peer0", 1},
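
Usage note (not part of the patch): a minimal sketch of how a caller might drive the new Pool.Update from a periodic configuration reload. It assumes the SDK's public pool API (NewPool, InitParameters.SetKey, InitParameters.AddNode, NewNodeParam); the import path, addresses, refresh interval, and the loadNodesFromConfig helper are illustrative assumptions only.

package main

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"log"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
)

// loadNodesFromConfig is a hypothetical helper standing in for whatever
// source (config file, service discovery) supplies the current node list.
func loadNodesFromConfig() []pool.NodeParam {
	return []pool.NodeParam{
		pool.NewNodeParam(1, "grpc://peer0:8080", 1),
		pool.NewNodeParam(2, "grpc://peer1:8080", 1),
	}
}

func main() {
	// Keep the context long-lived: the rebalance loop started by Update inherits it.
	ctx := context.Background()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		log.Fatal(err)
	}

	var prm pool.InitParameters
	prm.SetKey(key)
	prm.AddNode(pool.NewNodeParam(1, "grpc://peer0:8080", 1))

	p, err := pool.NewPool(prm)
	if err != nil {
		log.Fatal(err)
	}
	if err := p.Dial(ctx); err != nil {
		log.Fatal(err)
	}
	defer p.Close()

	// Refresh the node list in place; on error the previous node set stays active.
	for range time.Tick(time.Minute) {
		if err := p.Update(ctx, loadNodesFromConfig()); err != nil {
			log.Printf("pool update failed, keeping previous nodes: %v", err)
		}
	}
}

Because update() returns before swapping the pool on any error (and reports an unchanged node list via the equal flag), a retry loop like the one above never loses a working node set.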