package tree

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	metrics "git.frostfs.info/TrueCloudLab/frostfs-observability/metrics/grpc"
	tracing "git.frostfs.info/TrueCloudLab/frostfs-observability/tracing/grpc"
	"github.com/hashicorp/golang-lru/v2/simplelru"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/credentials/insecure"
)

// clientCache is an LRU cache of gRPC connections to tree services, keyed by
// the network address string passed to get.
//
// The embedded simplelru.LRU is not safe for concurrent use by itself; every
// access to it must happen while holding the embedded mutex.
type clientCache struct {
	sync.Mutex
	simplelru.LRU[string, cacheItem]
}

// cacheItem is a single clientCache entry.
//
// cc holds the connection produced by the last successful dial of the address,
// or is nil when the last dial attempt failed. lastTry records when that dial
// attempt finished; it is used to throttle re-dialing after a failure.
type cacheItem struct {
	cc      *grpc.ClientConn
	lastTry time.Time
}

const (
	// defaultClientCacheSize is the maximum number of addresses whose
	// connection state is kept in the cache.
	defaultClientCacheSize = 32

	// defaultClientConnectTimeout bounds a single dial attempt.
	defaultClientConnectTimeout = 2 * time.Second

	// defaultReconnectInterval is the minimum delay between a failed dial of
	// an address and the next attempt to dial it.
	defaultReconnectInterval = 15 * time.Second
)

// errRecentlyFailed is returned when the cache remembers a failed dial of the
// requested address that is younger than defaultReconnectInterval.
var errRecentlyFailed = errors.New("client has recently failed")

// init prepares the embedded LRU to hold up to defaultClientCacheSize entries.
// The LRU's eviction callback closes the gRPC connection of the entry being
// dropped, when the entry has one.
func (c *clientCache) init() {
	closeConn := func(_ string, item cacheItem) {
		if item.cc == nil {
			return
		}
		_ = item.cc.Close()
	}

	// NewLRU only fails for a non-positive size, and defaultClientCacheSize is
	// a positive constant, so the error is deliberately ignored.
	lru, _ := simplelru.NewLRU(defaultClientCacheSize, closeConn)
	c.LRU = *lru
}

// get returns a tree service client for the node at netmapAddr, reusing a
// cached gRPC connection when possible.
//
// A cached connection is reused while its state is Idle or Ready; a cached
// connection in any other state is closed and the address is dialed again.
// The outcome of every dial is recorded in the cache: after a failure, calls
// for the same address fail fast with errRecentlyFailed until
// defaultReconnectInterval has passed since that attempt.
//
// The cache is only accessed while holding c's mutex. Dialing happens without
// the lock, so several goroutines may dial the same address at once. When that
// happens, a usable connection already stored in the cache wins and the
// redundant one is closed, so no connection is leaked.
func (c *clientCache) get(ctx context.Context, netmapAddr string) (TreeServiceClient, error) {
	usable := func(cc *grpc.ClientConn) bool {
		s := cc.GetState()
		return s == connectivity.Idle || s == connectivity.Ready
	}

	c.Lock()
	item, ok := c.LRU.Get(netmapAddr)
	c.Unlock()

	if ok {
		if item.cc == nil {
			if d := time.Since(item.lastTry); d < defaultReconnectInterval {
				// Report the time left before the next attempt is allowed,
				// not the time elapsed since the previous one.
				return nil, fmt.Errorf("%w: %s till the next reconnection to %s",
					errRecentlyFailed, defaultReconnectInterval-d, netmapAddr)
			}
		} else {
			if usable(item.cc) {
				return NewTreeServiceClient(item.cc), nil
			}
			// The connection is broken or shutting down: drop it and redial.
			_ = item.cc.Close()
		}
	}

	cc, err := dialTreeService(ctx, netmapAddr)
	lastTry := time.Now()

	c.Lock()
	defer c.Unlock()

	// simplelru.LRU.Add overwrites an existing value WITHOUT calling the
	// eviction callback, so a connection stored by a concurrent get must be
	// handled here explicitly, otherwise it would never be closed.
	if prev, found := c.LRU.Peek(netmapAddr); found && prev.cc != nil && prev.cc != cc {
		if usable(prev.cc) {
			// Another goroutine already stored a working connection: keep it
			// and discard the one dialed here, if any.
			if cc != nil {
				_ = cc.Close()
			}
			return NewTreeServiceClient(prev.cc), nil
		}
		// The stored connection is unusable and is about to be replaced.
		// Closing an already-closed ClientConn merely returns an error.
		_ = prev.cc.Close()
	}

	if err != nil {
		c.LRU.Add(netmapAddr, cacheItem{cc: nil, lastTry: lastTry})
		return nil, err
	}

	c.LRU.Add(netmapAddr, cacheItem{cc: cc, lastTry: lastTry})
	return NewTreeServiceClient(cc), nil
}

// dialTreeService opens a gRPC connection to the tree service at netmapAddr,
// which must be parseable by network.Address.FromString.
//
// Dialing blocks (grpc.WithBlock) until the connection is established, ctx is
// done, or defaultClientConnectTimeout elapses, whichever happens first.
// Metrics and tracing interceptors are chained for both unary and streaming
// calls.
func dialTreeService(ctx context.Context, netmapAddr string) (*grpc.ClientConn, error) {
	var addr network.Address
	if err := addr.FromString(netmapAddr); err != nil {
		return nil, err
	}

	blocking := grpc.WithBlock()
	unary := grpc.WithChainUnaryInterceptor(
		metrics.NewUnaryClientInterceptor(),
		tracing.NewUnaryClientInteceptor(),
	)
	stream := grpc.WithChainStreamInterceptor(
		metrics.NewStreamClientInterceptor(),
		tracing.NewStreamClientInterceptor(),
	)
	opts := []grpc.DialOption{blocking, unary, stream}

	// NOTE(review): when TLS is enabled no transport credentials are added
	// here, and grpc-go refuses to dial without any — confirm TLS endpoints
	// are actually supported by this path.
	if !addr.IsTLSEnabled() {
		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}

	dialCtx, cancel := context.WithTimeout(ctx, defaultClientConnectTimeout)
	defer cancel()

	return grpc.DialContext(dialCtx, addr.URIAddr(), opts...)
}