package pool

import (
	"context"
	"crypto/ecdsa"
	crand "crypto/rand"
	"encoding/binary"
	"math/rand"
	"sort"
	"sync"
	"time"

	"google.golang.org/grpc/grpclog"

	"github.com/nspcc-dev/neofs-api-go/service"
	"github.com/nspcc-dev/neofs-api-go/state"
	"github.com/pkg/errors"
	"go.uber.org/atomic"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/keepalive"
)

type (
	// node is a single peer of the pool together with its gRPC connection.
	//
	// index is used by GetConnection to look the node up again as
	// pool.nodes[index]. weight is Peer.Weight multiplied by 100 and
	// truncated to an integer. usedAt is set when ReBalance dials the node
	// and whenever GetConnection hands its connection out; ReBalance compares
	// its age against the connection TTL. conn is nil until ReBalance dials
	// the node and is reset to nil when ReBalance drops the connection.
	node struct {
		index   int32
		address string
		weight  uint32
		usedAt  time.Time
		conn    *grpc.ClientConn
	}

	// Client hands out gRPC connections.
	Client interface {
		GetConnection(context.Context) (*grpc.ClientConn, error)
	}

	// Pool is a Client that manages connections to a weighted set of peers.
	Pool interface {
		Client

		// Close closes the connections held by the pool.
		Close()
		// Status returns nil when the pool is usable, or the reason it is not.
		Status() error
		// ReBalance (re)connects peers and refreshes the set of connections
		// GetConnection chooses from.
		ReBalance(ctx context.Context)
	}

	// Peer describes a node the pool may connect to. GetConnection tries
	// nodes with a higher Weight first.
	Peer struct {
		Address string
		Weight  float64
	}

	// Config configures a pool created by New.
	Config struct {
		// ClientParameters are the keepalive parameters passed to every dial.
		keepalive.ClientParameters

		// ConnectionTTL is compared against the time since a node's
		// connection was dialed or last handed out by GetConnection;
		// ReBalance closes connections older than that. ConnectTimeout
		// bounds each dial made by ReBalance, and RequestTimeout bounds each
		// health-check request.
		ConnectionTTL  time.Duration
		ConnectTimeout time.Duration
		RequestTimeout time.Duration

		// Peers lists the nodes the pool may connect to.
		Peers []Peer

		// When GRPCVerbose is set, New installs GRPCLogger as the global
		// gRPC logger.
		GRPCVerbose bool
		GRPCLogger  grpclog.LoggerV2

		// Logger receives the pool's own log messages. PrivateKey signs the
		// health-check request the pool sends to nodes.
		Logger     *zap.Logger
		PrivateKey *ecdsa.PrivateKey
	}

	// pool is the Pool implementation returned by New.
	pool struct {
		log *zap.Logger

		// ttl mirrors Config.ConnectionTTL.
		ttl time.Duration

		// Dial timeout, health-check timeout and keepalive parameters,
		// copied from Config.
		conTimeout time.Duration
		reqTimeout time.Duration
		opts       keepalive.ClientParameters

		// currentConn is the connection GetConnection handed out last;
		// currentIdx is the index in nodes of its node, or -1 when there is
		// none.
		currentIdx  *atomic.Int32
		currentConn *grpc.ClientConn

		// reqHealth is the signed health-check request reused by isAlive.
		reqHealth *state.HealthRequest

		// The mutex serializes ReBalance, GetConnection and Close. nodes holds
		// every registered peer; conns maps a weight to the nodes registered
		// under it by ReBalance; keys holds the weights for which ReBalance
		// found a live connection, sorted in descending order.
		*sync.Mutex
		nodes []*node
		keys  []uint32
		conns map[uint32][]*node

		// unhealthy holds the error returned by Status (nil when healthy).
		unhealthy *atomic.Error
	}
)

// Sentinel errors reported by the pool.
var (
	// errBootstrapping is the pool status from creation until the first
	// ReBalance records a result.
	errBootstrapping = errors.New("bootstrapping")

	// errEmptyConnection is returned by the health-check for a nil connection.
	errEmptyConnection = errors.New("empty connection")

	// errNoHealthyConnections is returned by GetConnection when there is no
	// connection it could hand out.
	errNoHealthyConnections = errors.New("no active connections")
)

// New creates a connection pool for the peers listed in cfg.
//
// No connections are opened here. Peers are only registered, and the first
// ReBalance call dials them; until then Status reports errBootstrapping.
// Peers with an empty address are skipped (with a warning) and the remaining
// peers are still registered. A peer's Weight is multiplied by 100 and
// truncated to an integer.
//
// New also seeds the global math/rand source from crypto/rand (it is used to
// pick randomly among equally weighted nodes). It installs cfg.GRPCLogger as
// the gRPC logger when cfg.GRPCVerbose is set, and signs the health-check
// request with cfg.PrivateKey.
func New(cfg *Config) (Pool, error) {
	p := &pool{
		log:   cfg.Logger,
		Mutex: new(sync.Mutex),
		keys:  make([]uint32, 0),
		nodes: make([]*node, 0, len(cfg.Peers)),
		conns: make(map[uint32][]*node),

		currentIdx: atomic.NewInt32(-1),

		ttl: cfg.ConnectionTTL,

		conTimeout: cfg.ConnectTimeout,
		reqTimeout: cfg.RequestTimeout,
		opts:       cfg.ClientParameters,

		unhealthy: atomic.NewError(errBootstrapping),
	}

	if cfg.GRPCVerbose {
		grpclog.SetLoggerV2(cfg.GRPCLogger)
	}

	buf := make([]byte, 8)
	if _, err := crand.Read(buf); err != nil {
		return nil, err
	}

	seed := binary.BigEndian.Uint64(buf)
	rand.Seed(int64(seed))
	cfg.Logger.Info("used random seed", zap.Uint64("seed", seed))

	p.reqHealth = new(state.HealthRequest)
	p.reqHealth.SetTTL(service.NonForwardingTTL)

	if err := service.SignRequestData(cfg.PrivateKey, p.reqHealth); err != nil {
		return nil, errors.Wrap(err, "could not sign `HealthRequest`")
	}

	for i := range cfg.Peers {
		peer := cfg.Peers[i]

		if peer.Address == "" {
			// Skip only this peer; the rest of the list is still registered.
			cfg.Logger.Warn("skip, empty address")
			continue
		}

		n := &node{
			// index must be the node's position in p.nodes (not in
			// cfg.Peers): GetConnection looks the node up via p.nodes[index].
			index:   int32(len(p.nodes)),
			address: peer.Address,
			weight:  uint32(peer.Weight * 100),
		}
		p.nodes = append(p.nodes, n)

		cfg.Logger.Info("add new peer",
			zap.String("address", n.address),
			zap.Uint32("weight", n.weight))
	}

	return p, nil
}

// Status reports the pool health recorded by the most recent ReBalance. It
// returns nil if GetConnection could hand out a connection at that point,
// and otherwise the error explaining why not. Before the first ReBalance it
// returns errBootstrapping.
func (p *pool) Status() error {
	if err := p.unhealthy.Load(); err != nil {
		return err
	}

	return nil
}

// Close closes every connection held by the pool and resets the pool's
// bookkeeping, so GetConnection no longer hands out the closed connections.
// Errors returned while closing are logged, not returned.
//
// After Close, Status reports errNoHealthyConnections. Because each node's
// connection is reset to nil, a later ReBalance re-dials the nodes.
func (p *pool) Close() {
	p.Lock()
	defer p.Unlock()

	for i := range p.nodes {
		if p.nodes[i] == nil || p.nodes[i].conn == nil {
			continue
		}

		p.log.Warn("close connection",
			zap.String("address", p.nodes[i].address),
			zap.Error(p.nodes[i].conn.Close()))

		// Forget the closed connection so ReBalance dials a fresh one instead
		// of closing this one a second time.
		p.nodes[i].conn = nil
	}

	// Drop every reference to the closed connections; otherwise GetConnection
	// would keep returning them from p.conns with a nil error.
	p.conns = make(map[uint32][]*node)
	p.keys = p.keys[:0]
	p.currentConn = nil
	p.currentIdx.Store(-1)
	p.unhealthy.Store(errNoHealthyConnections)
}

// ReBalance refreshes the set of connections GetConnection chooses from.
//
// For every registered node it:
//   - dials the node (blocking, bounded by the connect timeout) if it has no
//     connection yet; a failed dial is logged and the node is skipped;
//   - health-checks the connection. A connection is closed and dropped when
//     the check fails, or when the TTL is positive and the node's usedAt
//     timestamp is older than the TTL. usedAt is set on dial and whenever
//     GetConnection hands the connection out. A dropped node is re-dialed on
//     a later call;
//   - otherwise registers the node under its weight.
//
// It then rebuilds the list of weights that have a live connection, in
// descending order. If ctx is done mid-walk, ReBalance stops early and leaves
// the weight list as it was.
//
// The pool mutex is held for the whole walk, including dials and health-check
// RPCs. After releasing it, ReBalance calls GetConnection once and records
// that call's error as the status returned by Status.
func (p *pool) ReBalance(ctx context.Context) {
	p.Lock()
	defer func() {
		p.Unlock()

		// GetConnection acquires the mutex itself, so it must run after Unlock.
		_, err := p.GetConnection(ctx)
		p.unhealthy.Store(err)
	}()

	keys := make(map[uint32]struct{})

	p.log.Debug("re-balancing connections")

	for i := range p.nodes {
		var (
			idx    = -1
			exists bool
			err    error
			start  = time.Now()
			conn   = p.nodes[i].conn
			weight = p.nodes[i].weight
		)

		if err = ctx.Err(); err != nil {
			p.log.Warn("something went wrong", zap.Error(err))
			p.unhealthy.Store(err)

			return
		}

		if conn == nil {
			p.log.Debug("empty connection, try to connect",
				zap.String("address", p.nodes[i].address))

			dialCtx, cancel := context.WithTimeout(ctx, p.conTimeout)
			conn, err = grpc.DialContext(dialCtx, p.nodes[i].address,
				grpc.WithBlock(),
				grpc.WithInsecure(),
				grpc.WithKeepaliveParams(p.opts))
			cancel()

			if err != nil || conn == nil {
				p.log.Warn("skip, could not connect to node",
					zap.String("address", p.nodes[i].address),
					zap.Stringer("elapsed", time.Since(start)),
					zap.Error(err))
				continue
			}

			p.nodes[i].conn = conn
			p.nodes[i].usedAt = time.Now()
			p.log.Debug("connected to node", zap.String("address", p.nodes[i].address))
		}

		// Find whether this node is already registered for its weight.
		for j := range p.conns[weight] {
			if p.conns[weight][j] != nil && p.conns[weight][j].conn == conn {
				idx = j
				exists = true
				break
			}
		}

		usedAt := time.Since(p.nodes[i].usedAt)
		// A non-positive TTL disables idle expiry. Otherwise any elapsed time
		// would exceed it, and every connection (even a freshly dialed one)
		// would be dropped immediately.
		expired := p.ttl > 0 && usedAt > p.ttl

		// if something wrong with connection (bad state, unhealthy or not used a long time), try to close it and remove
		if err = p.isAlive(ctx, conn); err != nil || expired {
			p.log.Warn("connection not alive",
				zap.String("address", p.nodes[i].address),
				zap.Stringer("since", usedAt),
				zap.Error(err))

			if exists {
				// remove from connections
				p.conns[weight] = append(p.conns[weight][:idx], p.conns[weight][idx+1:]...)
			}

			if err = conn.Close(); err != nil {
				p.log.Warn("could not close bad connection",
					zap.String("address", p.nodes[i].address),
					zap.Stringer("since", usedAt),
					zap.Error(err))
			}

			p.nodes[i].conn = nil
			continue
		}

		keys[weight] = struct{}{}

		p.log.Debug("connection alive",
			zap.String("address", p.nodes[i].address),
			zap.Stringer("since", usedAt))

		if !exists {
			p.conns[weight] = append(p.conns[weight], p.nodes[i])
		}
	}

	p.keys = p.keys[:0]
	for w := range keys {
		p.keys = append(p.keys, w)
	}

	// Highest weight first: GetConnection walks p.keys in this order.
	sort.Slice(p.keys, func(i, j int) bool {
		return p.keys[i] > p.keys[j]
	})
}

// GetConnection returns a gRPC connection to one of the pool's nodes.
//
// The connection returned by the previous call is reused as long as it still
// passes a health-check (bounded by ctx and the request timeout). Otherwise a
// new one is picked from the nodes registered by ReBalance: weights are tried
// from highest to lowest, and among nodes of equal weight one is chosen at
// random. The connection that has just failed the health-check, and nodes
// without a connection, are never picked. The chosen node's usedAt timestamp
// is refreshed.
//
// When no connection can be picked, it returns ctx.Err() if ctx is done and
// errNoHealthyConnections otherwise.
func (p *pool) GetConnection(ctx context.Context) (*grpc.ClientConn, error) {
	p.Lock()
	defer p.Unlock()

	if err := p.isAlive(ctx, p.currentConn); err == nil {
		if id := p.currentIdx.Load(); id != -1 && p.nodes[id] != nil {
			p.nodes[id].usedAt = time.Now()
		}

		return p.currentConn, nil
	}

	// The current connection (if any) has just failed its check; never hand
	// it out again from the fallback below.
	failed := p.currentConn

	for _, w := range p.keys {
		candidates := make([]*node, 0, len(p.conns[w]))
		for _, n := range p.conns[w] {
			if n != nil && n.conn != nil && n.conn != failed {
				candidates = append(candidates, n)
			}
		}

		var picked *node
		switch ln := len(candidates); ln {
		case 0:
			continue // try the next (lower) weight
		case 1:
			picked = candidates[0]
		default: // > 1
			picked = candidates[rand.Intn(ln)]
		}

		picked.usedAt = time.Now()
		p.currentConn = picked.conn
		p.currentIdx.Store(picked.index)

		return p.currentConn, nil
	}

	p.currentConn = nil
	p.currentIdx.Store(-1)

	if ctx.Err() != nil {
		return nil, ctx.Err()
	}

	return nil, errNoHealthyConnections
}

// isAlive returns nil only when cur is in the Idle, Connecting or Ready state
// and a HealthCheck request, bounded by the pool's request timeout, reports
// a healthy node.
//
// A nil connection yields errEmptyConnection. Any other connectivity state
// yields an error carrying the state's name. A failed RPC is logged and its
// error returned, and an unhealthy answer yields an error carrying the
// node-reported status.
func (p *pool) isAlive(ctx context.Context, cur *grpc.ClientConn) error {
	if cur == nil {
		return errEmptyConnection
	}

	st := cur.GetState()
	if st != connectivity.Idle && st != connectivity.Ready && st != connectivity.Connecting {
		return errors.New(st.String())
	}

	checkCtx, cancel := context.WithTimeout(ctx, p.reqTimeout)
	defer cancel()

	res, err := state.NewStatusClient(cur).HealthCheck(checkCtx, p.reqHealth)

	switch {
	case err != nil:
		p.log.Warn("could not fetch health-check", zap.Error(err))
		return err
	case !res.Healthy:
		return errors.New(res.Status)
	default:
		return nil
	}
}