diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index c57d3f09..327757bb 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -45,15 +45,8 @@ func (pb *Builder) Build(ctx context.Context, options *BuilderOptions) (Pool, er if len(pb.addresses) == 0 { return nil, errors.New("no NeoFS peers configured") } - totalWeight := 0.0 - for _, w := range pb.weights { - totalWeight += w - } - for i, w := range pb.weights { - pb.weights[i] = w / totalWeight - } - options.weights = pb.weights + options.weights = adjustWeights(pb.weights) options.addresses = pb.addresses return newPool(ctx, options) } @@ -108,28 +101,75 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { ownerID := owner.NewIDFromNeo3Wallet(wallet) pool := &pool{sampler: sampler, owner: ownerID, clientPacks: clientPacks} - go func() { - ticker := time.NewTimer(options.ClientRebalanceInterval) - for range ticker.C { - ok := true - for i, clientPack := range pool.clientPacks { - func() { - tctx, c := context.WithTimeout(ctx, options.NodeRequestTimeout) - defer c() - if _, err := clientPack.client.EndpointInfo(tctx); err != nil { - ok = false - } - pool.lock.Lock() - pool.clientPacks[i].healthy = ok - pool.lock.Unlock() - }() - } - ticker.Reset(options.ClientRebalanceInterval) - } - }() + go startRebalance(ctx, pool, options) return pool, nil } +func startRebalance(ctx context.Context, p *pool, options *BuilderOptions) { + ticker := time.NewTimer(options.ClientRebalanceInterval) + buffer := make([]float64, len(options.weights)) + + for range ticker.C { + updateNodesHealth(ctx, p, options, buffer) + ticker.Reset(options.ClientRebalanceInterval) + } +} + +func updateNodesHealth(ctx context.Context, p *pool, options *BuilderOptions, bufferWeights []float64) { + if len(bufferWeights) != len(p.clientPacks) { + bufferWeights = make([]float64, len(p.clientPacks)) + } + healthyChanged := false + wg := sync.WaitGroup{} + for i, cPack := range p.clientPacks { + wg.Add(1) + go func(i int, netmap client.Netmap) { + defer wg.Done() + ok := true + tctx, c := context.WithTimeout(ctx, options.NodeRequestTimeout) + defer c() + if _, err := netmap.EndpointInfo(tctx); err != nil { + ok = false + bufferWeights[i] = 0 + } + if ok { + bufferWeights[i] = options.weights[i] + } + + p.lock.Lock() + if p.clientPacks[i].healthy != ok { + p.clientPacks[i].healthy = ok + healthyChanged = true + } + p.lock.Unlock() + }(i, cPack.client) + } + wg.Wait() + + if healthyChanged { + probabilities := adjustWeights(bufferWeights) + source := rand.NewSource(time.Now().UnixNano()) + p.lock.Lock() + p.sampler = NewSampler(probabilities, source) + p.lock.Unlock() + } +} + +func adjustWeights(weights []float64) []float64 { + adjusted := make([]float64, len(weights)) + sum := 0.0 + for _, weight := range weights { + sum += weight + } + if sum > 0 { + for i, weight := range weights { + adjusted[i] = weight / sum + } + } + + return adjusted +} + func (p *pool) Connection() (client.Client, *session.Token, error) { p.lock.RLock() defer p.lock.RUnlock() diff --git a/pkg/pool/sampler_test.go b/pkg/pool/sampler_test.go index ed6a3f7b..20eb5eb3 100644 --- a/pkg/pool/sampler_test.go +++ b/pkg/pool/sampler_test.go @@ -1,9 +1,12 @@ package pool import ( + "context" + "fmt" "math/rand" "testing" + "github.com/nspcc-dev/neofs-api-go/pkg/client" "github.com/stretchr/testify/require" ) @@ -38,3 +41,85 @@ func TestSamplerStability(t *testing.T) { require.Equal(t, tc.expected, res, "probabilities: %v", tc.probabilities) } } + +type netmapMock struct { + client.Client + name string + err error +} + +func newNetmapMock(name string, needErr bool) netmapMock { + var err error + if needErr { + err = fmt.Errorf("not available") + } + return netmapMock{name: name, err: err} +} + +func (n netmapMock) EndpointInfo(_ context.Context, _ ...client.CallOption) (*client.EndpointInfo, error) { + return nil, n.err +} + +func TestHealthyReweight(t *testing.T) { + var ( + weights = []float64{0.9, 0.1} + names = []string{"node0", "node1"} + options = &BuilderOptions{weights: weights} + buffer = make([]float64, len(weights)) + ) + + p := &pool{ + sampler: NewSampler(weights, rand.NewSource(0)), + clientPacks: []*clientPack{ + {client: newNetmapMock(names[0], true), healthy: true}, + {client: newNetmapMock(names[1], false), healthy: true}}, + } + + // check getting first node connection before rebalance happened + connection0, _, err := p.Connection() + require.NoError(t, err) + mock0 := connection0.(netmapMock) + require.Equal(t, names[0], mock0.name) + + updateNodesHealth(context.TODO(), p, options, buffer) + + connection1, _, err := p.Connection() + require.NoError(t, err) + mock1 := connection1.(netmapMock) + require.Equal(t, names[1], mock1.name) + + // enabled first node again + p.lock.Lock() + p.clientPacks[0].client = newNetmapMock(names[0], false) + p.lock.Unlock() + + updateNodesHealth(context.TODO(), p, options, buffer) + + connection0, _, err = p.Connection() + require.NoError(t, err) + mock0 = connection0.(netmapMock) + require.Equal(t, names[0], mock0.name) +} + +func TestHealthyNoReweight(t *testing.T) { + var ( + weights = []float64{0.9, 0.1} + names = []string{"node0", "node1"} + options = &BuilderOptions{weights: weights} + buffer = make([]float64, len(weights)) + ) + + sampler := NewSampler(weights, rand.NewSource(0)) + p := &pool{ + sampler: sampler, + clientPacks: []*clientPack{ + {client: newNetmapMock(names[0], false), healthy: true}, + {client: newNetmapMock(names[1], false), healthy: true}}, + } + + updateNodesHealth(context.TODO(), p, options, buffer) + + p.lock.RLock() + defer p.lock.RUnlock() + require.Truef(t, sampler == p.sampler, "Sampler must not be changed. Expected: %p, actual: %p", sampler, p.sampler) +}