package placement

import (
	"crypto/sha256"
	"fmt"
	"slices"
	"sync"

	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	"github.com/hashicorp/golang-lru/v2/simplelru"
)

// ContainerNodesCache caches results of ContainerNodes() invocation between epochs.
//
// Cached values are keyed by container ID and all belong to the single netmap
// epoch stored in lastEpoch; whenever ContainerNodes moves lastEpoch to another
// epoch, the whole cache is purged. Construct it with NewContainerNodesCache:
// the zero value leaves containerCache unset and cannot be used.
type ContainerNodesCache struct {
	// mtx protects lastEpoch and containerCache fields.
	mtx sync.Mutex
	// lastEpoch contains network map epoch for all values in the container cache.
	lastEpoch uint64
	// containerCache caches container nodes by ID. It is used to skip the `nm.ContainerNodes`
	// invocation if neither netmap nor container has changed.
	containerCache simplelru.LRUCache[cid.ID, [][]netmapSDK.NodeInfo]
}

// defaultContainerCacheSize is the container cache capacity (maximum number of
// container IDs kept) that NewContainerNodesCache falls back to when it is given
// a non-positive size.
const defaultContainerCacheSize = 10

// NewContainerNodesCache returns a cache that remembers ContainerNodes() results
// for up to size containers. A non-positive size is replaced with
// defaultContainerCacheSize (10).
func NewContainerNodesCache(size int) *ContainerNodesCache {
	capacity := size
	if capacity <= 0 {
		capacity = defaultContainerCacheSize
	}

	// simplelru.NewLRU only fails for a non-positive size, which is ruled out above,
	// so the error is safe to ignore.
	lru, _ := simplelru.NewLRU[cid.ID, [][]netmapSDK.NodeInfo](capacity, nil)

	c := new(ContainerNodesCache)
	c.containerCache = lru
	return c
}

// ContainerNodes returns the result of nm.ContainerNodes(), possibly from the cache.
//
// Results are cached per container ID for a single netmap epoch:
//   - a netmap with the cached epoch is served from the cache when possible;
//   - a netmap with a newer epoch purges the cache and becomes the cached epoch;
//   - a netmap with an older epoch is computed directly and leaves the cache
//     untouched, so occasional lookups against a previous epoch cannot evict
//     (or roll back) the entries computed for the current one.
//
// The returned outer and inner slices are fresh copies, so callers may modify
// them without affecting cached data.
func (c *ContainerNodesCache) ContainerNodes(nm *netmapSDK.NetMap, cnr cid.ID, p netmapSDK.PlacementPolicy) ([][]netmapSDK.NodeInfo, error) {
	epoch := nm.Epoch()

	c.mtx.Lock()
	switch {
	case epoch == c.lastEpoch:
		raw, ok := c.containerCache.Get(cnr)
		c.mtx.Unlock()
		if ok {
			return c.cloneResult(raw), nil
		}
	case epoch > c.lastEpoch:
		// A newer netmap invalidates everything computed for the previous epoch.
		c.lastEpoch = epoch
		c.containerCache.Purge()
		c.mtx.Unlock()
	default:
		// Stale netmap: compute without disturbing the cache of the current epoch.
		c.mtx.Unlock()
	}

	binCnr := make([]byte, sha256.Size)
	cnr.Encode(binCnr)

	cn, err := nm.ContainerNodes(p, binCnr)
	if err != nil {
		return nil, fmt.Errorf("could not get container nodes: %w", err)
	}

	c.mtx.Lock()
	// The epoch may have advanced while the lock was released; only cache results
	// that still belong to the current epoch (this also skips stale-epoch results).
	if c.lastEpoch == epoch {
		c.containerCache.Add(cnr, cn)
	}
	c.mtx.Unlock()
	return c.cloneResult(cn), nil
}

// cloneResult copies nodes so the caller cannot modify the cached slices: the
// outer slice and every per-replica inner slice are newly allocated, while the
// NodeInfo elements themselves are copied by value (shallowly).
// A nil or empty input yields an empty, non-nil result.
func (c *ContainerNodesCache) cloneResult(nodes [][]netmapSDK.NodeInfo) [][]netmapSDK.NodeInfo {
	cloned := make([][]netmapSDK.NodeInfo, 0, len(nodes))
	for _, replica := range nodes {
		cloned = append(cloned, slices.Clone(replica))
	}
	return cloned
}