diff --git a/pool.go b/pool.go
index d361881..0395cd4 100644
--- a/pool.go
+++ b/pool.go
@@ -14,6 +14,7 @@ import (
 	"github.com/nspcc-dev/neofs-api/service"
 	"github.com/nspcc-dev/neofs-api/state"
 	"github.com/spf13/viper"
+	"go.uber.org/atomic"
 	"go.uber.org/zap"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/connectivity"
@@ -22,23 +23,29 @@ import (
 
 type (
 	node struct {
+		index   int32
 		address string
 		weight  uint32
+		usedAt  time.Time
 		conn    *grpc.ClientConn
 	}
 
 	Pool struct {
		log *zap.Logger
 
+		ttl time.Duration
+
 		connectTimeout time.Duration
+		requestTimeout time.Duration
 		opts           keepalive.ClientParameters
 
-		cur *grpc.ClientConn
+		currentIdx  *atomic.Int32
+		currentConn *grpc.ClientConn
 
-		*sync.RWMutex
+		*sync.Mutex
 		nodes []*node
 		keys  []uint32
-		conns map[uint32][]*grpc.ClientConn
+		conns map[uint32][]*node
 	}
 )
 
@@ -49,20 +56,26 @@ var (
 
 func newPool(ctx context.Context, l *zap.Logger, v *viper.Viper) *Pool {
 	p := &Pool{
-		log:     l,
-		RWMutex: new(sync.RWMutex),
-		keys:    make([]uint32, 0),
-		nodes:   make([]*node, 0),
-		conns:   make(map[uint32][]*grpc.ClientConn),
+		log:   l,
+		Mutex: new(sync.Mutex),
+		keys:  make([]uint32, 0),
+		nodes: make([]*node, 0),
+		conns: make(map[uint32][]*node),
+
+		ttl: defaultTTL,
+
+		currentIdx: atomic.NewInt32(-1),
 
 		// fill with defaults:
-		connectTimeout: time.Second * 15,
+		requestTimeout: defaultRequestTimeout,
+		connectTimeout: defaultConnectTimeout,
 		opts: keepalive.ClientParameters{
-			Time:                time.Second * 10,
-			Timeout:             time.Minute * 5,
+			Time:                defaultKeepaliveTime,
+			Timeout:             defaultKeepaliveTimeout,
 			PermitWithoutStream: true,
 		},
 	}
+
 	buf := make([]byte, 8)
 	if _, err := crand.Read(buf); err != nil {
 		l.Panic("could not read seed", zap.Error(err))
@@ -72,6 +85,10 @@ func newPool(ctx context.Context, l *zap.Logger, v *viper.Viper) *Pool {
 	rand.Seed(int64(seed))
 	l.Info("used random seed", zap.Uint64("seed", seed))
 
+	if val := v.GetDuration("conn_ttl"); val > 0 {
+		p.ttl = val
+	}
+
 	if val := v.GetDuration("connect_timeout"); val > 0 {
 		p.connectTimeout = val
 	}
@@ -99,6 +116,7 @@
 		}
 
 		p.nodes = append(p.nodes, &node{
+			index:   int32(i),
 			address: address,
 			weight:  uint32(weight * 100),
 		})
@@ -110,13 +128,10 @@
 	p.reBalance(ctx)
 
-	cur, err := p.getConnection(ctx)
-	if err != nil {
+	if _, err := p.getConnection(ctx); err != nil {
 		l.Panic("could get connection", zap.Error(err))
 	}
 
-	p.cur = cur
-
 	return p
 }
@@ -173,21 +188,25 @@ func (p *Pool) reBalance(ctx context.Context) {
 			}
 
 			p.nodes[i].conn = conn
+			p.nodes[i].usedAt = time.Now()
 			p.log.Info("connected to node", zap.String("address", p.nodes[i].address))
 		}
 
 		for j := range p.conns[weight] {
-			if p.conns[weight][j] == conn {
+			if p.conns[weight][j] != nil && p.conns[weight][j].conn == conn {
 				idx = j
 				exists = true
 				break
 			}
 		}
 
-		// if something wrong with connection (bad state or unhealthy), try to close it and remove
-		if err = isAlive(ctx, p.log, conn); err != nil {
+		usedAt := time.Since(p.nodes[i].usedAt)
+
+		// if something wrong with connection (bad state, unhealthy or not used a long time), try to close it and remove
+		if err = p.isAlive(ctx, conn); err != nil || usedAt > p.ttl {
 			p.log.Warn("connection not alive",
 				zap.String("address", p.nodes[i].address),
+				zap.Duration("used_at", usedAt),
 				zap.Error(err))
 
 			if exists {
@@ -198,6 +217,7 @@
 			if err = conn.Close(); err != nil {
 				p.log.Warn("could not close bad connection",
 					zap.String("address", p.nodes[i].address),
+					zap.Duration("used_at", usedAt),
 					zap.Error(err))
 			}
 
@@ -209,8 +229,12 @@
 		keys[weight] = struct{}{}
 
+		p.log.Info("connection alive",
+			zap.String("address", p.nodes[i].address),
+			zap.Duration("used_at", usedAt))
+
 		if !exists {
-			p.conns[weight] = append(p.conns[weight], conn)
+			p.conns[weight] = append(p.conns[weight], p.nodes[i])
 		}
 	}
 
@@ -225,11 +249,15 @@
 }
 
 func (p *Pool) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
-	p.RLock()
-	defer p.RUnlock()
+	p.Lock()
+	defer p.Unlock()
 
-	if err := isAlive(ctx, p.log, p.cur); err == nil {
-		return p.cur, nil
+	if err := p.isAlive(ctx, p.currentConn); err == nil {
+		if id := p.currentIdx.Load(); id != -1 && p.nodes[id] != nil {
+			p.nodes[id].usedAt = time.Now()
+		}
+
+		return p.currentConn, nil
 	}
 
 	for _, w := range p.keys {
@@ -237,19 +265,26 @@
 		case 0:
 			continue
 		case 1:
-			p.cur = p.conns[w][0]
-			return p.cur, nil
+			p.currentConn = p.conns[w][0].conn
+			p.conns[w][0].usedAt = time.Now()
+			p.currentIdx.Store(p.conns[w][0].index)
+			return p.currentConn, nil
 		default: // > 1
 			i := rand.Intn(ln)
-			p.cur = p.conns[w][i]
-			return p.cur, nil
+			p.currentConn = p.conns[w][i].conn
+			p.conns[w][i].usedAt = time.Now()
+			p.currentIdx.Store(p.conns[w][i].index)
+			return p.currentConn, nil
 		}
 	}
 
+	p.currentConn = nil
+	p.currentIdx.Store(-1)
+
 	return nil, errNoHealthyConnections
 }
 
-func isAlive(ctx context.Context, log *zap.Logger, cur *grpc.ClientConn) error {
+func (p *Pool) isAlive(ctx context.Context, cur *grpc.ClientConn) error {
 	if cur == nil {
 		return errEmptyConnection
 	}
@@ -259,9 +294,12 @@ func isAlive(ctx context.Context, log *zap.Logger, cur *grpc.ClientConn) error {
 	req := new(state.HealthRequest)
 	req.SetTTL(service.NonForwardingTTL)
 
+	ctx, cancel := context.WithTimeout(ctx, p.requestTimeout)
+	defer cancel()
+
 	res, err := state.NewStatusClient(cur).HealthCheck(ctx, req)
 	if err != nil {
-		log.Warn("could not fetch health-check", zap.Error(err))
+		p.log.Warn("could not fetch health-check", zap.Error(err))
 
 		return err
 	} else if !res.Healthy {
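
For context, a standalone sketch of the selection policy this patch gives getConnection: walk the weight keys in order, take the first non-empty bucket, pick an entry at random, and refresh its usedAt stamp. For brevity the sketch folds the TTL check into the pick itself, whereas the patch performs it in reBalance; the names entry and pickHealthy and the addresses are illustrative only, not part of pool.go.

package main

import (
	"errors"
	"fmt"
	"math/rand"
	"time"
)

// entry stands in for the patched node type: a weighted peer that
// remembers when it was last handed out (usedAt).
type entry struct {
	address string
	healthy bool
	usedAt  time.Time
}

// pickHealthy walks the weight keys in order (as getConnection walks p.keys),
// drops entries that are unhealthy or idle longer than ttl, and chooses one
// of the survivors in the first non-empty bucket with rand.Intn, refreshing
// its usedAt stamp on the way out.
func pickHealthy(buckets map[uint32][]*entry, keys []uint32, ttl time.Duration) (*entry, error) {
	for _, w := range keys {
		var alive []*entry
		for _, e := range buckets[w] {
			if e.healthy && time.Since(e.usedAt) <= ttl {
				alive = append(alive, e)
			}
		}
		if len(alive) == 0 {
			continue
		}
		e := alive[rand.Intn(len(alive))]
		e.usedAt = time.Now()
		return e, nil
	}
	return nil, errors.New("no healthy connections")
}

func main() {
	now := time.Now()
	buckets := map[uint32][]*entry{
		100: {{address: "node-a:8080", healthy: true, usedAt: now.Add(-time.Hour)}}, // idle past ttl
		50:  {{address: "node-b:8080", healthy: true, usedAt: now}},
	}
	e, err := pickHealthy(buckets, []uint32{100, 50}, 30*time.Second)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("picked", e.address) // node-a is stale, so node-b is chosen
}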