frostfs-http-gw/pool.go
2020-02-14 13:12:36 +03:00

275 lines
5.6 KiB
Go

package main
import (
"context"
crand "crypto/rand"
"encoding/binary"
"errors"
"math/rand"
"sort"
"strconv"
"sync"
"time"
"github.com/nspcc-dev/neofs-api/service"
"github.com/nspcc-dev/neofs-api/state"
"github.com/spf13/viper"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/keepalive"
)
// node describes a single configured peer together with the connection
// that was established to it (nil while the peer is not connected).
type node struct {
	// address is the dial target passed to grpc.DialContext.
	address string

	// weight groups peers into buckets; buckets with a larger weight are
	// tried first when a connection is picked.
	weight uint32

	// conn is the live connection to the peer, or nil.
	conn *grpc.ClientConn
}

// Pool keeps gRPC connections to the configured peers, grouped by weight,
// and remembers the connection that is currently in use.
type Pool struct {
	log *zap.Logger

	// connectTimeout bounds a single dial attempt.
	connectTimeout time.Duration

	// opts are the keepalive parameters applied to every dialed connection.
	opts keepalive.ClientParameters

	// cur is the connection handed out most recently.
	cur *grpc.ClientConn

	// RWMutex guards the fields of the pool.
	*sync.RWMutex

	// nodes lists every configured peer, connected or not.
	nodes []*node

	// keys holds the weights that currently have at least one healthy
	// connection, sorted in descending order.
	keys []uint32

	// conns maps a weight to the healthy connections of that weight.
	conns map[uint32][]*grpc.ClientConn
}
// Sentinel errors reported by the connection pool.
var (
	// errEmptyConnection is returned by isAlive for a nil connection.
	errEmptyConnection = errors.New("empty connection")

	// errNoHealthyConnections is returned when no weight bucket holds a
	// usable connection.
	errNoHealthyConnections = errors.New("no active connections")
)
// newPool builds a connection pool from the configuration in v.
//
// It seeds math/rand from crypto/rand, applies the optional settings
// "connect_timeout", "keepalive.time", "keepalive.timeout" and
// "keepalive.permit_without_stream", and reads peers from
// "peers.<i>.address" / "peers.<i>.weight" for i = 0, 1, ... until the first
// empty address. The peers are then dialed via reBalance and an initial
// connection is selected.
//
// The function panics (through the logger) when the random seed cannot be
// read or when no connection could be obtained.
func newPool(ctx context.Context, l *zap.Logger, v *viper.Viper) *Pool {
	p := &Pool{
		log:     l,
		RWMutex: new(sync.RWMutex),
		keys:    make([]uint32, 0),
		nodes:   make([]*node, 0),
		conns:   make(map[uint32][]*grpc.ClientConn),

		// fill with defaults:
		connectTimeout: time.Second * 15,
		opts: keepalive.ClientParameters{
			Time:                time.Second * 10,
			Timeout:             time.Minute * 5,
			PermitWithoutStream: true,
		},
	}

	// Seed the global math/rand source with cryptographically random bytes;
	// it is used later to pick a connection among peers of equal weight.
	buf := make([]byte, 8)
	if _, err := crand.Read(buf); err != nil {
		l.Panic("could not read seed", zap.Error(err))
	}

	seed := binary.BigEndian.Uint64(buf)
	rand.Seed(int64(seed))
	l.Info("used random seed", zap.Uint64("seed", seed))

	// Override defaults only with explicitly configured, positive values.
	if val := v.GetDuration("connect_timeout"); val > 0 {
		p.connectTimeout = val
	}

	if val := v.GetDuration("keepalive.time"); val > 0 {
		p.opts.Time = val
	}

	if val := v.GetDuration("keepalive.timeout"); val > 0 {
		p.opts.Timeout = val
	}

	if v.IsSet("keepalive.permit_without_stream") {
		p.opts.PermitWithoutStream = v.GetBool("keepalive.permit_without_stream")
	}

	// Peers are numbered consecutively; the first missing address ends the list.
	for i := 0; ; i++ {
		key := "peers." + strconv.Itoa(i) + "."
		address := v.GetString(key + "address")
		weight := v.GetFloat64(key + "weight")

		if address == "" {
			l.Warn("skip, empty address")
			break
		}

		// The fractional weight is stored as an integer scaled by 100.
		p.nodes = append(p.nodes, &node{
			address: address,
			weight:  uint32(weight * 100),
		})

		l.Info("add new peer",
			zap.String("address", p.nodes[i].address),
			zap.Uint32("weight", p.nodes[i].weight))
	}

	p.reBalance(ctx)

	cur, err := p.getConnection(ctx)
	if err != nil {
		// The original message read "could get connection", which stated
		// the opposite of what happened.
		l.Panic("could not get connection", zap.Error(err))
	}

	p.cur = cur

	return p
}
// close shuts down the connection of every connected peer while holding the
// pool's write lock. The result of each Close call is logged next to the
// peer address.
func (p *Pool) close() {
	p.Lock()
	defer p.Unlock()

	for _, peer := range p.nodes {
		// Only peers that actually hold a connection need closing.
		if peer != nil && peer.conn != nil {
			p.log.Warn("close connection",
				zap.String("address", peer.address),
				zap.Error(peer.conn.Close()))
		}
	}
}
// reBalance walks over all configured peers under the pool's write lock and
// brings the per-weight connection buckets up to date:
//
//   - peers without a connection are dialed (blocking, bounded by
//     connectTimeout); peers that cannot be reached are skipped;
//   - connections that fail the isAlive check are removed from their bucket,
//     closed and detached from the peer;
//   - healthy connections are added to the bucket of the peer's weight if
//     they are not there yet.
//
// Finally p.keys is rebuilt with the weights that have a healthy connection,
// sorted from the largest to the smallest.
func (p *Pool) reBalance(ctx context.Context) {
	p.Lock()
	defer p.Unlock()

	// healthy collects the weights that own at least one live connection.
	healthy := make(map[uint32]struct{})

	p.log.Info("re-balancing connections")

	for _, peer := range p.nodes {
		started := time.Now()
		conn, weight := peer.conn, peer.weight

		if conn == nil {
			p.log.Warn("empty connection, try to connect",
				zap.String("address", peer.address))

			dialCtx, cancel := context.WithTimeout(ctx, p.connectTimeout)
			dialed, dialErr := grpc.DialContext(dialCtx, peer.address,
				grpc.WithBlock(),
				grpc.WithInsecure(),
				grpc.WithKeepaliveParams(p.opts))
			cancel()

			if dialErr != nil || dialed == nil {
				p.log.Warn("skip, could not connect to node",
					zap.String("address", peer.address),
					zap.Duration("elapsed", time.Since(started)),
					zap.Error(dialErr))
				continue
			}

			conn = dialed
			peer.conn = conn
			p.log.Info("connected to node", zap.String("address", peer.address))
		}

		// Position of the connection inside its weight bucket, -1 if absent.
		pos := -1
		for j, known := range p.conns[weight] {
			if known == conn {
				pos = j
				break
			}
		}

		// if something wrong with connection (bad state or unhealthy), try to close it and remove
		if aliveErr := isAlive(ctx, p.log, conn); aliveErr != nil {
			p.log.Warn("connection not alive",
				zap.String("address", peer.address),
				zap.Error(aliveErr))

			if pos >= 0 {
				// remove from connections
				bucket := p.conns[weight]
				p.conns[weight] = append(bucket[:pos], bucket[pos+1:]...)
			}

			if closeErr := conn.Close(); closeErr != nil {
				p.log.Warn("could not close bad connection",
					zap.String("address", peer.address),
					zap.Error(closeErr))
			}

			// Forget the connection so that the next run dials the peer again.
			peer.conn = nil
			continue
		}

		healthy[weight] = struct{}{}

		if pos < 0 {
			p.conns[weight] = append(p.conns[weight], conn)
		}
	}

	// Rebuild the weight list in place, highest weight first.
	p.keys = p.keys[:0]
	for weight := range healthy {
		p.keys = append(p.keys, weight)
	}

	sort.Slice(p.keys, func(a, b int) bool {
		return p.keys[b] < p.keys[a]
	})
}
// getConnection returns a connection to use for the next request.
//
// The currently selected connection is returned as long as it passes the
// isAlive check. Otherwise the weight buckets are scanned from the highest
// weight to the lowest and a random connection of the first non-empty bucket
// becomes the new current connection. errNoHealthyConnections is returned
// when every bucket is empty.
//
// Locking: p.cur is only read under the read lock and only written under the
// write lock. The health check itself runs without holding the lock, so
// concurrent callers do not serialize on the network round-trip.
func (p *Pool) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
	p.RLock()
	cur := p.cur
	p.RUnlock()

	if err := isAlive(ctx, p.log, cur); err == nil {
		return cur, nil
	}

	// The current connection is unusable: replacing it mutates the pool, so
	// the exclusive lock is required from here on.
	p.Lock()
	defer p.Unlock()

	for _, w := range p.keys {
		conns := p.conns[w]
		if len(conns) == 0 {
			continue
		}

		// Spread the load between peers of equal weight.
		p.cur = conns[rand.Intn(len(conns))]

		return p.cur, nil
	}

	return nil, errNoHealthyConnections
}
// isAlive reports whether cur can be used, returning nil if so.
//
// A nil connection yields errEmptyConnection. A connection whose state is not
// Idle, Ready or Connecting yields an error carrying the state name.
// Otherwise a non-forwarding health-check request is sent over the
// connection: a transport error is logged and returned as is, and an
// unhealthy reply is turned into an error carrying the reported status.
func isAlive(ctx context.Context, log *zap.Logger, cur *grpc.ClientConn) error {
	if cur == nil {
		return errEmptyConnection
	}

	st := cur.GetState()
	if st != connectivity.Idle && st != connectivity.Ready && st != connectivity.Connecting {
		return errors.New(st.String())
	}

	req := new(state.HealthRequest)
	req.SetTTL(service.NonForwardingTTL)

	res, err := state.NewStatusClient(cur).HealthCheck(ctx, req)
	if err != nil {
		log.Warn("could not fetch health-check", zap.Error(err))
		return err
	}

	if !res.Healthy {
		return errors.New(res.Status)
	}

	return nil
}