245 lines
4.7 KiB
Go
245 lines
4.7 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
crand "crypto/rand"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"math/rand"
|
||
|
"sort"
|
||
|
"strconv"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/spf13/viper"
|
||
|
"go.uber.org/zap"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/connectivity"
|
||
|
"google.golang.org/grpc/keepalive"
|
||
|
)
|
||
|
|
||
|
type (
|
||
|
node struct {
|
||
|
address string
|
||
|
weight uint32
|
||
|
conn *grpc.ClientConn
|
||
|
}
|
||
|
|
||
|
Pool struct {
|
||
|
log *zap.Logger
|
||
|
|
||
|
ctl time.Duration
|
||
|
opts keepalive.ClientParameters
|
||
|
|
||
|
cur *grpc.ClientConn
|
||
|
|
||
|
*sync.RWMutex
|
||
|
nodes []*node
|
||
|
keys []uint32
|
||
|
conns map[uint32][]*grpc.ClientConn
|
||
|
}
|
||
|
)
|
||
|
|
||
|
// errNoHealthyConnections is returned by getConnection when no weight bucket
// holds a connection to hand out.
var errNoHealthyConnections = errors.New("no active connections")
|
||
|
|
||
|
func newPool(ctx context.Context, l *zap.Logger, v *viper.Viper) *Pool {
|
||
|
p := &Pool{
|
||
|
log: l,
|
||
|
RWMutex: new(sync.RWMutex),
|
||
|
keys: make([]uint32, 0),
|
||
|
nodes: make([]*node, 0),
|
||
|
conns: make(map[uint32][]*grpc.ClientConn),
|
||
|
|
||
|
// fill with defaults:
|
||
|
ctl: time.Second * 15,
|
||
|
opts: keepalive.ClientParameters{
|
||
|
Time: time.Second * 10,
|
||
|
Timeout: time.Minute * 5,
|
||
|
PermitWithoutStream: true,
|
||
|
},
|
||
|
}
|
||
|
buf := make([]byte, 8)
|
||
|
if _, err := crand.Read(buf); err != nil {
|
||
|
l.Panic("could not read seed", zap.Error(err))
|
||
|
}
|
||
|
|
||
|
seed := binary.BigEndian.Uint64(buf)
|
||
|
rand.Seed(int64(seed))
|
||
|
l.Info("used random seed", zap.Uint64("seed", seed))
|
||
|
|
||
|
if val := v.GetDuration("connect_timeout"); val > 0 {
|
||
|
p.ctl = val
|
||
|
}
|
||
|
|
||
|
if val := v.GetDuration("keepalive.time"); val > 0 {
|
||
|
p.opts.Time = val
|
||
|
}
|
||
|
|
||
|
if val := v.GetDuration("keepalive.timeout"); val > 0 {
|
||
|
p.opts.Timeout = val
|
||
|
}
|
||
|
|
||
|
if v.IsSet("keepalive.permit_without_stream") {
|
||
|
p.opts.PermitWithoutStream = v.GetBool("keepalive.permit_without_stream")
|
||
|
}
|
||
|
|
||
|
for i := 0; ; i++ {
|
||
|
key := "peers." + strconv.Itoa(i) + "."
|
||
|
address := v.GetString(key + "address")
|
||
|
weight := v.GetFloat64(key + "weight")
|
||
|
|
||
|
if address == "" {
|
||
|
l.Warn("skip, empty address")
|
||
|
break
|
||
|
}
|
||
|
|
||
|
p.nodes = append(p.nodes, &node{
|
||
|
address: address,
|
||
|
weight: uint32(weight * 100),
|
||
|
})
|
||
|
|
||
|
l.Info("add new peer",
|
||
|
zap.String("address", p.nodes[i].address),
|
||
|
zap.Uint32("weight", p.nodes[i].weight))
|
||
|
}
|
||
|
|
||
|
p.reBalance(ctx)
|
||
|
|
||
|
cur, err := p.getConnection()
|
||
|
if err != nil {
|
||
|
l.Panic("could get connection", zap.Error(err))
|
||
|
}
|
||
|
|
||
|
p.cur = cur
|
||
|
|
||
|
return p
|
||
|
}
|
||
|
|
||
|
func (p *Pool) close() {
|
||
|
p.Lock()
|
||
|
defer p.Unlock()
|
||
|
|
||
|
for i := range p.nodes {
|
||
|
if p.nodes[i] == nil || p.nodes[i].conn == nil {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
p.log.Warn("close connection",
|
||
|
zap.String("address", p.nodes[i].address),
|
||
|
zap.Error(p.nodes[i].conn.Close()))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (p *Pool) reBalance(ctx context.Context) {
|
||
|
p.Lock()
|
||
|
defer p.Unlock()
|
||
|
|
||
|
keys := make(map[uint32]struct{})
|
||
|
|
||
|
for i := range p.nodes {
|
||
|
var (
|
||
|
idx = -1
|
||
|
exists bool
|
||
|
err error
|
||
|
conn = p.nodes[i].conn
|
||
|
weight = p.nodes[i].weight
|
||
|
)
|
||
|
|
||
|
if conn == nil {
|
||
|
p.log.Warn("empty connection, try to connect",
|
||
|
zap.String("address", p.nodes[i].address))
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(ctx, p.ctl)
|
||
|
conn, err = grpc.DialContext(ctx, p.nodes[i].address,
|
||
|
grpc.WithBlock(),
|
||
|
grpc.WithInsecure(),
|
||
|
grpc.WithKeepaliveParams(p.opts))
|
||
|
cancel()
|
||
|
|
||
|
if err != nil || conn == nil {
|
||
|
p.log.Warn("skip, could not connect to node",
|
||
|
zap.String("address", p.nodes[i].address),
|
||
|
zap.Error(err))
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
p.nodes[i].conn = conn
|
||
|
p.log.Info("connected to node", zap.String("address", p.nodes[i].address))
|
||
|
}
|
||
|
|
||
|
switch st := conn.GetState(); st {
|
||
|
case connectivity.Idle, connectivity.Ready, connectivity.Connecting:
|
||
|
keys[weight] = struct{}{}
|
||
|
|
||
|
for j := range p.conns[weight] {
|
||
|
if p.conns[weight][j] == conn {
|
||
|
exists = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if !exists {
|
||
|
p.conns[weight] = append(p.conns[weight], conn)
|
||
|
}
|
||
|
|
||
|
// Problems with connection, try to remove from :
|
||
|
default:
|
||
|
for j := range p.conns[weight] {
|
||
|
if p.conns[weight][j] == conn {
|
||
|
idx = j
|
||
|
exists = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if exists {
|
||
|
// remove from connections
|
||
|
p.conns[weight] = append(p.conns[weight][:idx], p.conns[weight][idx+1:]...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
p.keys = p.keys[:0]
|
||
|
for w := range keys {
|
||
|
p.keys = append(p.keys, w)
|
||
|
}
|
||
|
|
||
|
sort.Slice(p.keys, func(i, j int) bool {
|
||
|
return p.keys[i] > p.keys[j]
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (p *Pool) getConnection() (*grpc.ClientConn, error) {
|
||
|
p.RLock()
|
||
|
defer p.RUnlock()
|
||
|
|
||
|
if p.cur != nil && isAlive(p.cur) {
|
||
|
return p.cur, nil
|
||
|
}
|
||
|
|
||
|
for _, w := range p.keys {
|
||
|
switch ln := len(p.conns[w]); ln {
|
||
|
case 0:
|
||
|
continue
|
||
|
case 1:
|
||
|
p.cur = p.conns[w][0]
|
||
|
return p.cur, nil
|
||
|
default: // > 1
|
||
|
i := rand.Intn(ln)
|
||
|
p.cur = p.conns[w][i]
|
||
|
return p.cur, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil, errNoHealthyConnections
|
||
|
}
|
||
|
|
||
|
func isAlive(cur *grpc.ClientConn) bool {
|
||
|
switch st := cur.GetState(); st {
|
||
|
case connectivity.Idle, connectivity.Ready, connectivity.Connecting:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}
|