package policer

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	objectcore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/replicator"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

// errNoECinfoReturnded is returned by resolveLocalECChunks when the local header
// request for an EC parent object succeeds instead of failing with an EC info error.
var errNoECinfoReturnded = errors.New("no EC info returned")

// ecChunkProcessResult is the outcome of processECChunk for a locally stored EC chunk.
//
// validPlacement is set when the local node is the node required to hold the chunk.
// removeLocal is set when the required (remote) node answered the header request
// for the chunk without error, which makes the local copy a candidate for removal.
type ecChunkProcessResult struct {
	validPlacement bool
	removeLocal    bool
}

// errInvalidECPlacement is returned when the placement built for an object of an
// EC container does not consist of exactly one placement vector.
var errInvalidECPlacement = errors.New("invalid EC placement: EC placement must have one placement vector")

// processECContainerObject dispatches policy processing of an object stored in an
// EC container: objects carrying EC info are handled as EC chunks, every other
// object is handled as a replicated one.
func (p *Policer) processECContainerObject(ctx context.Context, objInfo objectcore.Info, policy netmap.PlacementPolicy) error {
	if objInfo.ECInfo != nil {
		return p.processECContainerECObject(ctx, objInfo, policy)
	}
	return p.processECContainerRepObject(ctx, objInfo, policy)
}

// processECContainerRepObject handles objects of an EC container that are not
// erasure coded themselves: tombstones, locks and linking objects.
// All of them must be stored on all of the container nodes.
//
// It returns an error if the placement cannot be built, if the placement does not
// consist of exactly one vector, or if ctx is already done.
func (p *Policer) processECContainerRepObject(ctx context.Context, objInfo objectcore.Info, policy netmap.PlacementPolicy) error {
	id := objInfo.Address.Object()
	vectors, err := p.placementBuilder.BuildPlacement(objInfo.Address.Container(), &id, policy)
	if err != nil {
		return fmt.Errorf("%s: %w", logs.PolicerCouldNotBuildPlacementVectorForObject, err)
	}
	if len(vectors) != 1 {
		return errInvalidECPlacement
	}

	requirements := &placementRequirements{}
	visited := newNodeCache()

	// Do not start node checks if processing was cancelled in the meantime.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}

	// The whole vector length is passed as the count argument, in line with the
	// "stored on all of the container nodes" rule above.
	nodes := vectors[0]
	p.processRepNodes(ctx, requirements, objInfo, nodes, uint32(len(nodes)), visited)

	if requirements.needLocalCopy || !requirements.removeLocalCopy {
		return nil
	}

	p.log.Info(logs.PolicerRedundantLocalObjectCopyDetected,
		zap.Stringer("object", objInfo.Address),
	)
	p.cbRedundantCopy(ctx, objInfo.Address)
	return nil
}

// processECContainerECObject handles a locally stored EC chunk: it checks the chunk
// placement, pulls chunks the local node is required to hold, adjusts the placement
// of the whole EC object and finally reports the local chunk as redundant if it may
// be dropped.
//
// The placement is built for the EC parent object. An error is returned if it cannot
// be built, if it does not consist of exactly one vector, or if ctx is already done.
func (p *Policer) processECContainerECObject(ctx context.Context, objInfo objectcore.Info, policy netmap.PlacementPolicy) error {
	vectors, err := p.placementBuilder.BuildPlacement(objInfo.Address.Container(), &objInfo.ECInfo.ParentID, policy)
	if err != nil {
		return fmt.Errorf("%s: %w", logs.PolicerCouldNotBuildPlacementVectorForObject, err)
	}
	if len(vectors) != 1 {
		return errInvalidECPlacement
	}
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}

	nodes := vectors[0]
	chunkResult := p.processECChunk(ctx, objInfo, nodes)

	dropLocal := chunkResult.removeLocal
	if dropLocal && !chunkResult.validPlacement {
		// drop local chunk only if all required chunks are in place
		dropLocal = p.pullRequiredECChunks(ctx, objInfo, nodes)
	}
	p.adjustECPlacement(ctx, objInfo, nodes, policy)

	if !dropLocal {
		return nil
	}
	p.log.Info(logs.PolicerRedundantLocalObjectCopyDetected, zap.Stringer("object", objInfo.Address))
	p.cbRedundantCopy(ctx, objInfo.Address)
	return nil
}

// processECChunk replicates EC chunk if needed.
//
// The node required to hold the chunk is picked from nodes by the chunk index
// modulo the vector length. If that node is the local one, the placement is reported
// as valid. Otherwise the chunk header is requested from the required node (bounded
// by the head timeout): on success the local copy is reported as removable, on
// "object not found" a replication task to that node is scheduled.
func (p *Policer) processECChunk(ctx context.Context, objInfo objectcore.Info, nodes []netmap.NodeInfo) ecChunkProcessResult {
	target := nodes[int(objInfo.ECInfo.Index)%len(nodes)]

	if p.cfg.netmapKeys.IsLocalKey(target.PublicKey()) {
		// current node is required node, we are happy
		return ecChunkProcessResult{validPlacement: true}
	}
	if target.IsMaintenance() {
		// consider maintenance mode has object, but do not drop local copy
		p.log.Debug(logs.PolicerConsiderNodeUnderMaintenanceAsOK, zap.String("node", netmap.StringifyPublicKey(target)))
		return ecChunkProcessResult{}
	}

	headCtx, cancel := context.WithTimeout(ctx, p.headTimeout)
	_, err := p.remoteHeader(headCtx, target, objInfo.Address, false)
	cancel()

	var result ecChunkProcessResult
	switch {
	case err == nil:
		result.removeLocal = true
	case client.IsErrObjectNotFound(err):
		p.log.Debug(logs.PolicerShortageOfObjectCopiesDetected, zap.Stringer("object", objInfo.Address), zap.Uint32("shortage", 1))
		p.replicator.HandleReplicationTask(ctx, replicator.Task{
			NumCopies: 1,
			Addr:      objInfo.Address,
			Nodes:     []netmap.NodeInfo{target},
		}, newNodeCache())
	case client.IsErrNodeUnderMaintenance(err):
		// consider maintenance mode has object, but do not drop local copy
		p.log.Debug(logs.PolicerConsiderNodeUnderMaintenanceAsOK, zap.String("node", netmap.StringifyPublicKey(target)))
	default:
		p.log.Error(logs.PolicerReceiveObjectHeaderToCheckPolicyCompliance, zap.Stringer("object", objInfo.Address), zap.String("error", err.Error()))
	}
	return result
}

// pullRequiredECChunks makes sure the chunks the local node is required to hold
// (according to nodes) are stored locally, scheduling pull tasks for the missing ones.
//
// It returns true only when nothing has to be pulled: either the local node is not
// required to hold any chunk, or every required chunk is already stored locally.
// It returns false on any failure and also after pull tasks were scheduled.
func (p *Policer) pullRequiredECChunks(ctx context.Context, objInfo objectcore.Info, nodes []netmap.NodeInfo) bool {
	cnr := objInfo.Address.Container()

	var parent oid.Address
	parent.SetContainer(cnr)
	parent.SetObject(objInfo.ECInfo.ParentID)

	missing := p.collectRequiredECChunks(nodes, objInfo)
	if len(missing) == 0 {
		p.log.Info(logs.PolicerNodeIsNotECObjectNode, zap.Stringer("object", objInfo.ECInfo.ParentID))
		return true
	}

	// Chunks that are already stored locally are removed from the set.
	if err := p.resolveLocalECChunks(ctx, parent, missing); err != nil {
		p.log.Error(logs.PolicerFailedToGetLocalECChunks, zap.Error(err), zap.Stringer("object", parent))
		return false
	}
	if len(missing) == 0 {
		return true
	}

	// Find out which remote nodes hold the remaining chunks and what their IDs are.
	chunkIDByIndex := make(map[uint32]oid.ID)
	if !p.resolveRemoteECChunks(ctx, parent, nodes, missing, chunkIDByIndex) {
		return false
	}

	for chunkIndex, holders := range missing {
		var chunkAddr oid.Address
		chunkAddr.SetContainer(cnr)
		chunkAddr.SetObject(chunkIDByIndex[chunkIndex])
		p.replicator.HandlePullTask(ctx, replicator.Task{
			Addr:  chunkAddr,
			Nodes: holders,
		})
	}
	// there was some missing chunks, it's not ok
	return false
}

// collectRequiredECChunks returns the chunk indexes the local node is required to
// hold: the positions, among the first ECInfo.Total nodes of the vector, at which
// the node key is a local key. Every index is mapped to an empty candidate list.
func (p *Policer) collectRequiredECChunks(nodes []netmap.NodeInfo, objInfo objectcore.Info) map[uint32][]netmap.NodeInfo {
	result := make(map[uint32][]netmap.NodeInfo)
	total := objInfo.ECInfo.Total
	for pos := 0; pos < len(nodes) && uint32(pos) != total; pos++ {
		if !p.cfg.netmapKeys.IsLocalKey(nodes[pos].PublicKey()) {
			continue
		}
		result[uint32(pos)] = []netmap.NodeInfo{}
	}
	return result
}

// resolveLocalECChunks removes from required every chunk index that is reported by
// the local storage for the EC parent object.
//
// The local header request for the parent is expected to fail with an EC info error
// listing the locally stored chunks. Any other error is returned as is, and a
// successful request results in errNoECinfoReturnded.
func (p *Policer) resolveLocalECChunks(ctx context.Context, parentAddress oid.Address, required map[uint32][]netmap.NodeInfo) error {
	var ecInfoErr *objectSDK.ECInfoError

	_, err := p.localHeader(ctx, parentAddress)
	switch {
	case err == nil: // should not happen
		return errNoECinfoReturnded
	case !errors.As(err, &ecInfoErr):
		return err
	}

	for _, chunk := range ecInfoErr.ECInfo().Chunks {
		delete(required, chunk.Index)
	}
	return nil
}

// resolveRemoteECChunks asks every non-local node for the EC info of the parent
// object and, for each chunk index present in required, records the node as a pull
// candidate in required and the chunk object ID in indexToObjectID.
//
// It returns false if a chunk ID cannot be decoded, if two nodes report different
// IDs for the same chunk index, or if some required chunk has no candidate node.
// Nodes whose response is not an EC info error are skipped.
func (p *Policer) resolveRemoteECChunks(ctx context.Context, parentAddress oid.Address, nodes []netmap.NodeInfo, required map[uint32][]netmap.NodeInfo, indexToObjectID map[uint32]oid.ID) bool {
	var ecInfoErr *objectSDK.ECInfoError

	for _, node := range nodes {
		if p.cfg.netmapKeys.IsLocalKey(node.PublicKey()) {
			continue
		}
		if _, err := p.remoteHeader(ctx, node, parentAddress, true); !errors.As(err, &ecInfoErr) {
			continue
		}

		for _, chunk := range ecInfoErr.ECInfo().Chunks {
			holders, wanted := required[chunk.Index]
			if !wanted {
				continue
			}
			required[chunk.Index] = append(holders, node)

			var decoded oid.ID
			if err := decoded.ReadFromV2(chunk.ID); err != nil {
				p.log.Error(logs.PolicerFailedToDecodeECChunkID, zap.Error(err), zap.Stringer("object", parentAddress))
				return false
			}

			known, seen := indexToObjectID[chunk.Index]
			if seen && known != decoded {
				p.log.Error(logs.PolicerDifferentObjectIDForTheSameECChunk, zap.Stringer("first", known),
					zap.Stringer("second", decoded), zap.Stringer("object", parentAddress), zap.Uint32("index", chunk.Index))
				return false
			}
			indexToObjectID[chunk.Index] = decoded
		}
	}

	// Every required chunk must be available on at least one remote node.
	for index, holders := range required {
		if len(holders) > 0 {
			continue
		}
		p.log.Error(logs.PolicerMissingECChunk, zap.Stringer("object", parentAddress), zap.Uint32("index", index))
		return false
	}

	return true
}

// adjustECPlacement inspects which nodes of the placement vector hold which chunks
// of the EC parent object and tries to repair the placement.
//
// Every node is asked (locally or remotely) for the EC info of the parent. Reported
// chunks are collected per index together with their object IDs. A remote node among
// the first ECInfo.Total positions that does not answer with EC info gets the local
// chunk replicated to it, and restoring is disabled for this pass. If no such node
// was met, some chunks are missing and the number of missing chunks does not exceed
// the parity count of the policy, the missing chunks are restored via restoreECObject.
func (p *Policer) adjustECPlacement(ctx context.Context, objInfo objectcore.Info, nodes []netmap.NodeInfo, policy netmap.PlacementPolicy) {
	var parentAddress oid.Address
	parentAddress.SetContainer(objInfo.Address.Container())
	parentAddress.SetObject(objInfo.ECInfo.ParentID)

	total := objInfo.ECInfo.Total
	holdersByIndex := make(map[uint32][]netmap.NodeInfo)
	idByIndex := make(map[uint32]oid.ID)
	canRestore := true // do not restore EC chunks if some node returned error
	var ecInfoErr *objectSDK.ECInfoError

	for pos, node := range nodes {
		beyondChunks := uint32(pos) >= total
		if beyondChunks && uint32(len(holdersByIndex)) == total {
			// All chunks are found, no need to look at the remaining nodes.
			return
		}

		isLocal := p.cfg.netmapKeys.IsLocalKey(node.PublicKey())
		var err error
		if isLocal {
			_, err = p.localHeader(ctx, parentAddress)
		} else {
			_, err = p.remoteHeader(ctx, node, parentAddress, true)
		}

		if !errors.As(err, &ecInfoErr) {
			if isLocal || beyondChunks {
				continue
			}
			p.log.Warn(logs.PolicerCouldNotGetObjectFromNodeMoving, zap.String("node", hex.EncodeToString(node.PublicKey())), zap.Stringer("object", parentAddress), zap.Error(err))
			p.replicator.HandleReplicationTask(ctx, replicator.Task{
				NumCopies: 1,
				Addr:      objInfo.Address,
				Nodes:     []netmap.NodeInfo{node},
			}, newNodeCache())
			canRestore = false
			continue
		}

		for _, chunk := range ecInfoErr.ECInfo().Chunks {
			holdersByIndex[chunk.Index] = append(holdersByIndex[chunk.Index], node)

			var reportedID oid.ID
			if decodeErr := reportedID.ReadFromV2(chunk.ID); decodeErr != nil {
				p.log.Error(logs.PolicerFailedToDecodeECChunkID, zap.Error(decodeErr), zap.Stringer("object", parentAddress))
				return
			}
			knownID, seen := idByIndex[chunk.Index]
			if seen && knownID != reportedID {
				p.log.Error(logs.PolicerDifferentObjectIDForTheSameECChunk, zap.Stringer("first", knownID),
					zap.Stringer("second", reportedID), zap.Stringer("object", parentAddress), zap.Uint32("index", chunk.Index))
				return
			}
			idByIndex[chunk.Index] = reportedID
		}
	}

	if !canRestore || uint32(len(holdersByIndex)) == total {
		return
	}

	missingCount := total - uint32(len(holdersByIndex))
	if missingCount > policy.ReplicaDescriptor(0).GetECParityCount() {
		var found []uint32
		for index := range holdersByIndex {
			found = append(found, index)
		}
		p.log.Error(logs.PolicerCouldNotRestoreObjectNotEnoughChunks, zap.Stringer("object", parentAddress), zap.Uint32s("found_chunks", found))
		return
	}
	p.restoreECObject(ctx, objInfo, parentAddress, nodes, holdersByIndex, idByIndex, policy)
}

// restoreECObject reconstructs the chunks of the EC object that are absent from
// existedChunks and schedules their placement.
//
// The available chunks are fetched via collectExistedChunks, the absent ones are
// reconstructed with an erasure code constructor configured from the first replica
// descriptor of the policy. Each chunk whose index is absent from existedChunks is
// then put locally or replicated to the node at position index modulo len(nodes).
// Any failure is logged and aborts the restore.
func (p *Policer) restoreECObject(ctx context.Context, objInfo objectcore.Info, parentAddress oid.Address, nodes []netmap.NodeInfo, existedChunks map[uint32][]netmap.NodeInfo, chunkIDs map[uint32]oid.ID, policy netmap.PlacementPolicy) {
	logFailure := func(err error) {
		p.log.Error(logs.PolicerFailedToRestoreObject, zap.Stringer("object", parentAddress), zap.Error(err))
	}

	ecRule := policy.ReplicaDescriptor(0)
	constructor, err := erasurecode.NewConstructor(int(ecRule.GetECDataCount()), int(ecRule.GetECParityCount()))
	if err != nil {
		logFailure(err)
		return
	}

	chunks := p.collectExistedChunks(ctx, objInfo, existedChunks, parentAddress, chunkIDs)
	if chunks == nil {
		return
	}

	key, err := p.keyStorage.GetKey(nil)
	if err != nil {
		logFailure(err)
		return
	}

	// Every chunk that could not be fetched has to be reconstructed.
	toReconstruct := make([]bool, len(chunks))
	for i := range chunks {
		toReconstruct[i] = chunks[i] == nil
	}
	if err := constructor.ReconstructParts(chunks, toReconstruct, key); err != nil {
		logFailure(err)
		return
	}

	for idx, chunk := range chunks {
		if _, present := existedChunks[uint32(idx)]; present {
			// Some node already holds this chunk, nothing to place.
			continue
		}

		chunkID, _ := chunk.ID()
		var chunkAddr oid.Address
		chunkAddr.SetContainer(parentAddress.Container())
		chunkAddr.SetObject(chunkID)

		task := replicator.Task{
			Addr: chunkAddr,
			Obj:  chunk,
		}
		target := nodes[idx%len(nodes)]
		if p.cfg.netmapKeys.IsLocalKey(target.PublicKey()) {
			p.replicator.HandleLocalPutTask(ctx, task)
			continue
		}
		task.NumCopies = 1
		task.Nodes = []netmap.NodeInfo{target}
		p.replicator.HandleReplicationTask(ctx, task, newNodeCache())
	}
}

// collectExistedChunks fetches the chunks listed in existedChunks concurrently and
// returns them as a slice of length ECInfo.Total indexed by chunk index. Entries of
// chunks that are not listed or could not be fetched from any of their nodes stay nil.
//
// For every chunk its nodes are tried in order (local storage for the local node,
// a remote request otherwise) until one of them returns the object. The chunk
// object ID is taken from chunkIDs. Nil is returned only if waiting for the
// workers reports an error.
//
// Chunk indexes in existedChunks come from EC info reported by nodes, so an index
// that does not fit into [0, ECInfo.Total) is logged and skipped instead of being
// used to index the result slice, which would panic inside a worker goroutine.
func (p *Policer) collectExistedChunks(ctx context.Context, objInfo objectcore.Info, existedChunks map[uint32][]netmap.NodeInfo, parentAddress oid.Address, chunkIDs map[uint32]oid.ID) []*objectSDK.Object {
	total := objInfo.ECInfo.Total
	parts := make([]*objectSDK.Object, total)
	errGroup, egCtx := errgroup.WithContext(ctx)
	for idx, nodes := range existedChunks {
		if idx >= total {
			p.log.Warn(logs.PolicerCouldNotGetChunk, zap.Stringer("object", parentAddress),
				zap.Error(fmt.Errorf("chunk index %d is out of range [0, %d)", idx, total)))
			continue
		}
		idx := idx
		nodes := nodes
		errGroup.Go(func() error {
			var objID oid.Address
			objID.SetContainer(parentAddress.Container())
			objID.SetObject(chunkIDs[idx])
			var obj *objectSDK.Object
			var err error
			for _, node := range nodes {
				if p.cfg.netmapKeys.IsLocalKey(node.PublicKey()) {
					obj, err = p.localObject(egCtx, objID)
				} else {
					obj, err = p.remoteObject(egCtx, node, objID)
				}
				if err == nil {
					break
				}
				p.log.Warn(logs.PolicerCouldNotGetChunk, zap.Stringer("object", parentAddress), zap.Stringer("chunkID", objID), zap.Error(err), zap.String("node", hex.EncodeToString(node.PublicKey())))
			}
			if obj != nil {
				// Each worker writes to its own index, so no extra synchronization is used.
				parts[idx] = obj
			}
			// A chunk that could not be fetched is not an error: it may be reconstructed later.
			return nil
		})
	}
	if err := errGroup.Wait(); err != nil {
		p.log.Error(logs.PolicerCouldNotGetChunks, zap.Stringer("object", parentAddress), zap.Error(err))
		return nil
	}
	return parts
}