Refactor 'Pool' #316

Merged
fyrchik merged 7 commits from achuprov/frostfs-sdk-go:feat/refactorPool into master 2025-03-07 11:45:32 +00:00
6 changed files with 1766 additions and 1684 deletions

pool/client.go (1283 lines, Normal file)

File diff suppressed because it is too large.

pool/connection_manager.go (330 lines, Normal file)

@@ -0,0 +1,330 @@
package pool
import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
"time"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type innerPool struct {
lock sync.RWMutex
sampler *sampler
clients []client
}
type connectionManager struct {
innerPools []*innerPool
rebalanceParams rebalanceParameters
clientBuilder clientBuilder
logger *zap.Logger
healthChecker *healthCheck
}
// newConnectionManager returns an instance of connectionManager configured according to the parameters.
//
// Before using connectionManager, you MUST call Dial.
func newConnectionManager(options InitParameters) (*connectionManager, error) {
if options.key == nil {
return nil, fmt.Errorf("missed required parameter 'Key'")
}
nodesParams, err := adjustNodeParams(options.nodeParams)
if err != nil {
return nil, err
}
manager := &connectionManager{
logger: options.logger,
rebalanceParams: rebalanceParameters{
nodesParams: nodesParams,
nodeRequestTimeout: options.healthcheckTimeout,
clientRebalanceInterval: options.clientRebalanceInterval,
sessionExpirationDuration: options.sessionExpirationDuration,
},
clientBuilder: options.clientBuilder,
}
return manager, nil
}
func (cm *connectionManager) dial(ctx context.Context) error {
inner := make([]*innerPool, len(cm.rebalanceParams.nodesParams))
var atLeastOneHealthy bool
for i, params := range cm.rebalanceParams.nodesParams {
clients := make([]client, len(params.weights))
for j, addr := range params.addresses {
clients[j] = cm.clientBuilder(addr)
if err := clients[j].dial(ctx); err != nil {
cm.log(zap.WarnLevel, "failed to build client", zap.String("address", addr), zap.Error(err))
continue
}
atLeastOneHealthy = true
}
source := rand.NewSource(time.Now().UnixNano())
sampl := newSampler(params.weights, source)
inner[i] = &innerPool{
sampler: sampl,
clients: clients,
}
}
if !atLeastOneHealthy {
return fmt.Errorf("at least one node must be healthy")
}
cm.innerPools = inner
cm.healthChecker = newHealthCheck(cm.rebalanceParams.clientRebalanceInterval)
cm.healthChecker.startRebalance(ctx, cm.rebalance)
return nil
}
func (cm *connectionManager) rebalance(ctx context.Context) {
buffers := make([][]float64, len(cm.rebalanceParams.nodesParams))
for i, params := range cm.rebalanceParams.nodesParams {
buffers[i] = make([]float64, len(params.weights))
}
cm.updateNodesHealth(ctx, buffers)
}
func (cm *connectionManager) log(level zapcore.Level, msg string, fields ...zap.Field) {
if cm.logger == nil {
return
}
cm.logger.Log(level, msg, fields...)
}
func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) {
if len(nodeParams) == 0 {
return nil, errors.New("no FrostFS peers configured")
}
nodesParamsMap := make(map[int]*nodesParam)
for _, param := range nodeParams {
nodes, ok := nodesParamsMap[param.priority]
if !ok {
nodes = &nodesParam{priority: param.priority}
}
nodes.addresses = append(nodes.addresses, param.address)
nodes.weights = append(nodes.weights, param.weight)
nodesParamsMap[param.priority] = nodes
}
nodesParams := make([]*nodesParam, 0, len(nodesParamsMap))
for _, nodes := range nodesParamsMap {
nodes.weights = adjustWeights(nodes.weights)
nodesParams = append(nodesParams, nodes)
}
sort.Slice(nodesParams, func(i, j int) bool {
return nodesParams[i].priority < nodesParams[j].priority
})
return nodesParams, nil
}
func (cm *connectionManager) updateNodesHealth(ctx context.Context, buffers [][]float64) {
wg := sync.WaitGroup{}
for i, inner := range cm.innerPools {
wg.Add(1)
bufferWeights := buffers[i]
go func(i int, _ *innerPool) {
defer wg.Done()
cm.updateInnerNodesHealth(ctx, i, bufferWeights)
}(i, inner)
}
wg.Wait()
}
func (cm *connectionManager) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights []float64) {
if i > len(cm.innerPools)-1 {
return
}
pool := cm.innerPools[i]
options := cm.rebalanceParams
healthyChanged := new(atomic.Bool)
wg := sync.WaitGroup{}
for j, cli := range pool.clients {
wg.Add(1)
go func(j int, cli client) {
defer wg.Done()
tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout)
defer c()
changed, err := restartIfUnhealthy(tctx, cli)
healthy := err == nil
if healthy {
bufferWeights[j] = options.nodesParams[i].weights[j]
} else {
bufferWeights[j] = 0
}
if changed {
fields := []zap.Field{zap.String("address", cli.address()), zap.Bool("healthy", healthy)}
if err != nil {
fields = append(fields, zap.String("reason", err.Error()))
}
cm.log(zap.DebugLevel, "health has changed", fields...)
healthyChanged.Store(true)
}
}(j, cli)
}
wg.Wait()
if healthyChanged.Load() {
probabilities := adjustWeights(bufferWeights)
source := rand.NewSource(time.Now().UnixNano())
pool.lock.Lock()
pool.sampler = newSampler(probabilities, source)
pool.lock.Unlock()
}
}
// restartIfUnhealthy checks the health status of the client and recreates it if it is unhealthy.
// It reports whether the status was changed by this call and returns the error that caused the unhealthy status.
func restartIfUnhealthy(ctx context.Context, c client) (changed bool, err error) {
defer func() {
if err != nil {
c.setUnhealthy()
} else {
c.setHealthy()
}
}()
wasHealthy := c.isHealthy()
if res, err := c.healthcheck(ctx); err == nil {
if res.Status().IsMaintenance() {
return wasHealthy, new(apistatus.NodeUnderMaintenance)
}
return !wasHealthy, nil
}
if err = c.restart(ctx); err != nil {
return wasHealthy, err
}
res, err := c.healthcheck(ctx)
if err != nil {
return wasHealthy, err
}
if res.Status().IsMaintenance() {
return wasHealthy, new(apistatus.NodeUnderMaintenance)
}
return !wasHealthy, nil
}
func adjustWeights(weights []float64) []float64 {
adjusted := make([]float64, len(weights))
sum := 0.0
for _, weight := range weights {
sum += weight
}
if sum > 0 {
for i, weight := range weights {
adjusted[i] = weight / sum
}
}
return adjusted
}
func (cm *connectionManager) connection() (client, error) {
for _, inner := range cm.innerPools {
cp, err := inner.connection()
if err == nil {
return cp, nil
}
}
return nil, errors.New("no healthy client")
}
// iterate calls cb for each healthy client in all innerPools.
func (cm *connectionManager) iterate(cb func(client)) {
for _, inner := range cm.innerPools {
for _, cl := range inner.clients {
if cl.isHealthy() {
cb(cl)
}
}
}
}
func (p *innerPool) connection() (client, error) {
p.lock.RLock() // lock is needed because p.sampler is read below
defer p.lock.RUnlock()
if len(p.clients) == 1 {
cp := p.clients[0]
if cp.isHealthy() {
return cp, nil
}
return nil, errors.New("no healthy client")
}
attempts := 3 * len(p.clients)
for range attempts {
i := p.sampler.Next()
if cp := p.clients[i]; cp.isHealthy() {
return cp, nil
}
}
return nil, errors.New("no healthy client")
}
func (cm connectionManager) Statistic() Statistic {
stat := Statistic{}
for _, inner := range cm.innerPools {
nodes := make([]string, 0, len(inner.clients))
for _, cl := range inner.clients {
if cl.isHealthy() {
nodes = append(nodes, cl.address())
}
node := NodeStatistic{
address: cl.address(),
methods: cl.methodsStatus(),
overallErrors: cl.overallErrorRate(),
currentErrors: cl.currentErrorRate(),
}
stat.nodes = append(stat.nodes, node)
stat.overallErrors += node.overallErrors
}
if len(stat.currentNodes) == 0 {
stat.currentNodes = nodes
}
}
return stat
}
func (cm *connectionManager) close() {
cm.healthChecker.stopRebalance()
// close all clients
for _, pools := range cm.innerPools {
for _, cli := range pools.clients {
_ = cli.close()
}
}
}
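A quick standalone illustration of the weight normalization driving the rebalance path above: adjustWeights turns the per-node weights (with unhealthy nodes already zeroed by updateInnerNodesHealth) into the probability vector fed to newSampler. The function below is a local copy for demonstration, not an import from the package.

package main

import "fmt"

// adjustWeights mirrors the helper above: it normalizes weights so they sum to 1.
func adjustWeights(weights []float64) []float64 {
	adjusted := make([]float64, len(weights))
	sum := 0.0
	for _, w := range weights {
		sum += w
	}
	if sum > 0 {
		for i, w := range weights {
			adjusted[i] = w / sum
		}
	}
	return adjusted
}

func main() {
	// Healthy nodes keep their configured weight; unhealthy ones are zeroed
	// before normalization, so they get zero sampling probability.
	fmt.Println(adjustWeights([]float64{1, 1, 2})) // [0.25 0.25 0.5]
	fmt.Println(adjustWeights([]float64{0, 1, 2})) // [0 0.333... 0.666...]
}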

pool/healthcheck.go (47 lines, Normal file)

@@ -0,0 +1,47 @@
package pool
import (
"context"
"time"
)
type healthCheck struct {
cancel context.CancelFunc
closedCh chan struct{}
clientRebalanceInterval time.Duration
}
func newHealthCheck(clientRebalanceInterval time.Duration) *healthCheck {
var h healthCheck
h.clientRebalanceInterval = clientRebalanceInterval
h.closedCh = make(chan struct{})
return &h
}
// startRebalance runs a loop that monitors connection health status.
func (h *healthCheck) startRebalance(ctx context.Context, callback func(ctx context.Context)) {
ctx, cancel := context.WithCancel(ctx)
h.cancel = cancel
go func() {
ticker := time.NewTicker(h.clientRebalanceInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
close(h.closedCh)
return
case <-ticker.C:
callback(ctx)
ticker.Reset(h.clientRebalanceInterval)
}
}
}()
}
func (h *healthCheck) stopRebalance() {
h.cancel()
<-h.closedCh
}
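Since the new healthCheck type is only wired up inside connectionManager.dial and connectionManager.close, here is a self-contained sketch of the same start/stop pattern for review convenience. The type and names below are local stand-ins, not the package's exports.

package main

import (
	"context"
	"fmt"
	"time"
)

// healthCheck mirrors the type above: a ticker-driven loop with explicit shutdown.
type healthCheck struct {
	cancel   context.CancelFunc
	closedCh chan struct{}
	interval time.Duration
}

func newHealthCheck(interval time.Duration) *healthCheck {
	return &healthCheck{closedCh: make(chan struct{}), interval: interval}
}

func (h *healthCheck) startRebalance(ctx context.Context, callback func(context.Context)) {
	ctx, cancel := context.WithCancel(ctx)
	h.cancel = cancel
	go func() {
		ticker := time.NewTicker(h.interval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				close(h.closedCh) // signal that the loop has fully exited
				return
			case <-ticker.C:
				callback(ctx)
				ticker.Reset(h.interval)
			}
		}
	}()
}

func (h *healthCheck) stopRebalance() {
	h.cancel()    // cancel the loop's context...
	<-h.closedCh // ...and block until the goroutine has actually exited.
}

func main() {
	h := newHealthCheck(100 * time.Millisecond)
	h.startRebalance(context.Background(), func(context.Context) { fmt.Println("rebalance tick") })
	time.Sleep(350 * time.Millisecond)
	h.stopRebalance()
}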

File diff suppressed because it is too large.

@@ -104,7 +104,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
expectedAuthKey := frostfsecdsa.PublicKey(clientKeys[1].PublicKey)
condition := func() bool {
cp, err := clientPool.connection()
cp, err := clientPool.manager.connection()
if err != nil {
return false
}
@@ -141,7 +141,7 @@ func TestOneNode(t *testing.T) {
require.NoError(t, err)
t.Cleanup(pool.Close)
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey)
@@ -171,7 +171,7 @@ func TestTwoNodes(t *testing.T) {
require.NoError(t, err)
t.Cleanup(pool.Close)
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.True(t, assertAuthKeyForAny(st, clientKeys))
@@ -220,13 +220,12 @@ func TestOneOfTwoFailed(t *testing.T) {
err = pool.Dial(context.Background())
require.NoError(t, err)
require.NoError(t, err)
t.Cleanup(pool.Close)
time.Sleep(2 * time.Second)
for range 5 {
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.True(t, assertAuthKeyForAny(st, clientKeys))
@@ -369,7 +368,7 @@ func TestUpdateNodesHealth(t *testing.T) {
tc.prepareCli(cli)
p, log := newPool(t, cli)
p.updateNodesHealth(ctx, [][]float64{{1}})
p.manager.updateNodesHealth(ctx, [][]float64{{1}})
changed := tc.wasHealthy != tc.willHealthy
require.Equalf(t, tc.willHealthy, cli.isHealthy(), "healthy status should be: %v", tc.willHealthy)
@@ -385,19 +384,19 @@ func newPool(t *testing.T, cli *mockClient) (*Pool, *observer.ObservedLogs) {
require.NoError(t, err)
return &Pool{
innerPools: []*innerPool{{
sampler: newSampler([]float64{1}, rand.NewSource(0)),
clients: []client{cli},
}},
cache: cache,
key: newPrivateKey(t),
closedCh: make(chan struct{}),
rebalanceParams: rebalanceParameters{
nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}},
nodeRequestTimeout: time.Second,
clientRebalanceInterval: 200 * time.Millisecond,
},
logger: log,
cache: cache,
key: newPrivateKey(t),
manager: &connectionManager{
innerPools: []*innerPool{{
sampler: newSampler([]float64{1}, rand.NewSource(0)),
clients: []client{cli},
}},
healthChecker: newHealthCheck(200 * time.Millisecond),
rebalanceParams: rebalanceParameters{
nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}},
nodeRequestTimeout: time.Second,
},
logger: log},
}, observedLog
}
@@ -435,7 +434,7 @@ func TestTwoFailed(t *testing.T) {
time.Sleep(2 * time.Second)
_, err = pool.connection()
_, err = pool.manager.connection()
require.Error(t, err)
require.Contains(t, err.Error(), "no healthy")
}
@@ -469,7 +468,7 @@ func TestSessionCache(t *testing.T) {
t.Cleanup(pool.Close)
// cache must contain session token
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.True(t, st.AssertAuthKey(&expectedAuthKey))
@@ -482,7 +481,7 @@ func TestSessionCache(t *testing.T) {
require.Error(t, err)
// cache must not contain session token
cp, err = pool.connection()
cp, err = pool.manager.connection()
require.NoError(t, err)
_, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.False(t, ok)
@@ -494,7 +493,7 @@ func TestSessionCache(t *testing.T) {
require.NoError(t, err)
// cache must contain session token
cp, err = pool.connection()
cp, err = pool.manager.connection()
require.NoError(t, err)
st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.True(t, st.AssertAuthKey(&expectedAuthKey))
@@ -538,7 +537,7 @@ func TestPriority(t *testing.T) {
expectedAuthKey1 := frostfsecdsa.PublicKey(clientKeys[0].PublicKey)
firstNode := func() bool {
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
return st.AssertAuthKey(&expectedAuthKey1)
@@ -546,7 +545,7 @@ func TestPriority(t *testing.T) {
expectedAuthKey2 := frostfsecdsa.PublicKey(clientKeys[1].PublicKey)
secondNode := func() bool {
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
return st.AssertAuthKey(&expectedAuthKey2)
@@ -583,7 +582,7 @@ func TestSessionCacheWithKey(t *testing.T) {
require.NoError(t, err)
// cache must contain session token
cp, err := pool.connection()
cp, err := pool.manager.connection()
require.NoError(t, err)
st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false))
require.True(t, st.AssertAuthKey(&expectedAuthKey))
@@ -636,9 +635,8 @@ func TestSessionTokenOwner(t *testing.T) {
cc.sessionTarget = func(tok session.Object) {
tkn = tok
}
err = p.initCallContext(&cc, prm, prmCtx)
err = p.initCall(&cc, prm, prmCtx)
require.NoError(t, err)
err = p.openDefaultSession(ctx, &cc)
require.NoError(t, err)
require.True(t, tkn.VerifySignature())
@@ -922,14 +920,14 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
t.Cleanup(pool.Close)
for range errorThreshold {
conn, err := pool.connection()
conn, err := pool.manager.connection()
require.NoError(t, err)
require.Equal(t, nodes[0].address, conn.address())
_, err = conn.objectGet(ctx, PrmObjectGet{})
require.Error(t, err)
}
conn, err := pool.connection()
conn, err := pool.manager.connection()
require.NoError(t, err)
require.Equal(t, nodes[1].address, conn.address())
_, err = conn.objectGet(ctx, PrmObjectGet{})

@@ -47,9 +47,6 @@ func TestHealthyReweight(t *testing.T) {
buffer = make([]float64, len(weights))
)
cache, err := newCache(0)
require.NoError(t, err)
client1 := newMockClient(names[0], *newPrivateKey(t))
client1.errOnDial()
@@ -59,22 +56,20 @@ func TestHealthyReweight(t *testing.T) {
sampler: newSampler(weights, rand.NewSource(0)),
clients: []client{client1, client2},
}
p := &Pool{
cm := &connectionManager{
innerPools: []*innerPool{inner},
cache: cache,
key: newPrivateKey(t),
rebalanceParams: rebalanceParameters{nodesParams: []*nodesParam{{weights: weights}}},
}
// check getting first node connection before rebalance happened
connection0, err := p.connection()
connection0, err := cm.connection()
require.NoError(t, err)
mock0 := connection0.(*mockClient)
require.Equal(t, names[0], mock0.address())
p.updateInnerNodesHealth(context.TODO(), 0, buffer)
cm.updateInnerNodesHealth(context.TODO(), 0, buffer)
connection1, err := p.connection()
connection1, err := cm.connection()
require.NoError(t, err)
mock1 := connection1.(*mockClient)
require.Equal(t, names[1], mock1.address())
@@ -84,10 +79,10 @@ func TestHealthyReweight(t *testing.T) {
inner.clients[0] = newMockClient(names[0], *newPrivateKey(t))
inner.lock.Unlock()
p.updateInnerNodesHealth(context.TODO(), 0, buffer)
cm.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.sampler = newSampler(weights, rand.NewSource(0))
connection0, err = p.connection()
connection0, err = cm.connection()
require.NoError(t, err)
mock0 = connection0.(*mockClient)
require.Equal(t, names[0], mock0.address())
@@ -108,12 +103,12 @@ func TestHealthyNoReweight(t *testing.T) {
newMockClient(names[1], *newPrivateKey(t)),
},
}
p := &Pool{
cm := &connectionManager{
innerPools: []*innerPool{inner},
rebalanceParams: rebalanceParameters{nodesParams: []*nodesParam{{weights: weights}}},
}
p.updateInnerNodesHealth(context.TODO(), 0, buffer)
cm.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.lock.RLock()
defer inner.lock.RUnlock()
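
The test churn above boils down to one structural change: connection selection and rebalancing now live behind Pool's manager field instead of on Pool itself, which is why call sites move from pool.connection() to pool.manager.connection(). A toy, self-contained sketch of that delegation shape (all types below are simplified stand-ins, not the real definitions from this PR):

package main

import (
	"errors"
	"fmt"
)

// Simplified stand-ins for the real client/connectionManager/Pool types.
type client interface {
	isHealthy() bool
	address() string
}

type toyClient struct {
	addr    string
	healthy bool
}

func (c toyClient) isHealthy() bool { return c.healthy }
func (c toyClient) address() string { return c.addr }

type connectionManager struct{ clients []client }

// connection returns the first healthy client (the real code samples by weight).
func (cm *connectionManager) connection() (client, error) {
	for _, c := range cm.clients {
		if c.isHealthy() {
			return c, nil
		}
	}
	return nil, errors.New("no healthy client")
}

// Pool delegates connection selection instead of owning it, hence
// pool.connection() becoming pool.manager.connection() in the tests.
type Pool struct{ manager *connectionManager }

func main() {
	p := Pool{manager: &connectionManager{clients: []client{
		toyClient{addr: "peer0", healthy: false},
		toyClient{addr: "peer1", healthy: true},
	}}}
	conn, err := p.manager.connection()
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(conn.address()) // peer1
}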