package getsvc

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"
	"sync"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/client"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/policy"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger"
	apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

// errECPartsRetrieveCompleted is returned by chunk-fetching goroutines once
// enough EC parts have been collected. It cancels the remaining work of the
// errgroup, and processECNodesRequests does not treat it as a failure.
var (
	errECPartsRetrieveCompleted = errors.New("EC parts receive completed")
)

// ecRemoteStorage is the set of requests the EC assembler sends to other
// storage nodes.
type ecRemoteStorage interface {
	// headObjectFromNode requests the header of the object at addr from the
	// node described by info. tryGetChunkListFromNode calls it with raw set to
	// true and expects an *objectSDK.ECInfoError listing the node's chunks.
	headObjectFromNode(ctx context.Context, addr oid.Address, info client.NodeInfo, raw bool) (*objectSDK.Object, error)
	// getObjectFromNode requests the full object at addr from the node
	// described by info.
	getObjectFromNode(ctx context.Context, addr oid.Address, info client.NodeInfo) (*objectSDK.Object, error)
}

// assemblerec reassembles an erasure-coded object from its chunks, which are
// read from local storage and/or requested from remote nodes.
type assemblerec struct {
	// addr is the address of the EC object being assembled; chunk addresses
	// reuse its container ID.
	addr               oid.Address
	// ecInfo holds chunk locations known up front: localChunks are read from
	// localStorage, remoteChunks are looked up by node public key.
	ecInfo             *ecInfo
	// rng, when non-nil (and head is false), restricts output to this payload range.
	rng                *objectSDK.Range
	// remoteStorage performs HEAD/GET requests to other nodes.
	remoteStorage      ecRemoteStorage
	// localStorage performs HEAD/GET requests against this node's storage.
	localStorage       localStorage
	log                *logger.Logger
	// head selects header-only assembly: chunks are HEADed instead of fetched.
	head               bool
	// traverserGenerator builds the placement traverser used to find chunk holders.
	traverserGenerator traverserGenerator
	// epoch is passed to traverserGenerator when building the traverser.
	epoch              uint64
}

// newAssemblerEC creates an assembler for the erasure-coded object at addr.
//
// head and rng select the assembly mode (see Assemble). Chunks listed in ecInfo
// are read through localStorage or remoteStorage. tg and epoch are used to
// build the placement traverser for the object's container.
func newAssemblerEC(
	addr oid.Address,
	ecInfo *ecInfo,
	rng *objectSDK.Range,
	remoteStorage ecRemoteStorage,
	localStorage localStorage,
	log *logger.Logger,
	head bool,
	tg traverserGenerator,
	epoch uint64,
) *assemblerec {
	a := new(assemblerec)
	a.addr = addr
	a.ecInfo = ecInfo
	a.rng = rng
	a.remoteStorage = remoteStorage
	a.localStorage = localStorage
	a.log = log
	a.head = head
	a.traverserGenerator = tg
	a.epoch = epoch
	return a
}

// Assemble assembles the erasure-coded object and writes its content to writer.
//
// The assembler's mode decides what is written. In head mode only the header
// is written. Otherwise, with a range set, only that payload range is written.
// Otherwise the whole object (header, then payload) is written.
// It returns the parent object.
func (a *assemblerec) Assemble(ctx context.Context, writer ObjectWriter) (*objectSDK.Object, error) {
	if a.head {
		return a.reconstructHeader(ctx, writer)
	}
	if a.rng != nil {
		return a.reconstructRange(ctx, writer)
	}
	return a.reconstructObject(ctx, writer)
}

// getConstructor creates an erasure-code constructor. It uses the data and
// parity chunk counts taken from the container's placement policy.
func (a *assemblerec) getConstructor(cnr *container.Container) (*erasurecode.Constructor, error) {
	pp := cnr.Value.PlacementPolicy()
	return erasurecode.NewConstructor(policy.ECDataCount(pp), policy.ECParityCount(pp))
}

// reconstructHeader rebuilds the parent object's header from chunk headers and
// writes it to writer.
//
// It returns the reconstructed object on success. On any error (reconstruction
// or writing) it returns nil together with that error, matching the other
// reconstruct* methods.
func (a *assemblerec) reconstructHeader(ctx context.Context, writer ObjectWriter) (*objectSDK.Object, error) {
	obj, err := a.reconstructObjectFromParts(ctx, true)
	if err != nil {
		return nil, err
	}
	if err := writer.WriteHeader(ctx, obj); err != nil {
		return nil, err
	}
	return obj, nil
}

// reconstructRange rebuilds the whole object and writes to writer only the
// payload bytes selected by a.rng.
//
// It returns *apistatus.ObjectOutOfRange when offset+length overflows or the
// range does not fit inside the payload. Any other error is passed through,
// with a nil object.
func (a *assemblerec) reconstructRange(ctx context.Context, writer ObjectWriter) (*objectSDK.Object, error) {
	obj, err := a.reconstructObjectFromParts(ctx, false)
	if err != nil {
		return nil, err
	}

	offset := a.rng.GetOffset()
	end := offset + a.rng.GetLength()
	payloadLen := uint64(len(obj.Payload()))
	// end < offset detects uint64 wrap-around of offset+length.
	if end < offset || offset > payloadLen || end > payloadLen {
		return nil, &apistatus.ObjectOutOfRange{}
	}

	if err := writer.WriteChunk(ctx, obj.Payload()[offset:end]); err != nil {
		return nil, err
	}
	return obj, nil
}

// reconstructObject rebuilds the whole object and streams it to writer. It
// writes the header with the payload cut off first, then the full payload as a
// single chunk.
//
// It returns the reconstructed object on success. On any error it returns nil
// with that error. The original returned the object when WriteHeader failed
// but nil when WriteChunk failed.
func (a *assemblerec) reconstructObject(ctx context.Context, writer ObjectWriter) (*objectSDK.Object, error) {
	obj, err := a.reconstructObjectFromParts(ctx, false)
	if err != nil {
		return nil, err
	}
	if err := writer.WriteHeader(ctx, obj.CutPayload()); err != nil {
		return nil, err
	}
	if err := writer.WriteChunk(ctx, obj.Payload()); err != nil {
		return nil, err
	}
	return obj, nil
}

// reconstructObjectFromParts collects the EC chunks of a.addr and decodes the
// parent object from them.
//
// With headers set, only the parent header is reconstructed
// (Constructor.ReconstructHeader). Otherwise the full object, payload included,
// is reconstructed (Constructor.Reconstruct).
func (a *assemblerec) reconstructObjectFromParts(ctx context.Context, headers bool) (*objectSDK.Object, error) {
	cnrID := a.addr.Container()
	objID := a.addr.Object()

	traverser, cnr, err := a.traverserGenerator.GenerateTraverser(cnrID, &objID, a.epoch)
	if err != nil {
		return nil, err
	}

	constructor, err := a.getConstructor(cnr)
	if err != nil {
		return nil, err
	}

	parts := a.retrieveParts(ctx, traverser, cnr)
	if !headers {
		return constructor.Reconstruct(parts)
	}
	return constructor.ReconstructHeader(parts)
}

// retrieveParts drains trav into a flat list of candidate nodes and then
// gathers EC chunks from local storage and from those nodes.
//
// A retrieval error is only logged at debug level. In that case the returned
// slice is nil and the decoder reports the failure downstream.
func (a *assemblerec) retrieveParts(ctx context.Context, trav *placement.Traverser, cnr *container.Container) []*objectSDK.Object {
	pp := cnr.Value.PlacementPolicy()
	dataCount, parityCount := policy.ECDataCount(pp), policy.ECParityCount(pp)

	var remoteNodes []placement.Node
	for batch := trav.Next(); len(batch) > 0; batch = trav.Next() {
		remoteNodes = append(remoteNodes, batch...)
	}

	parts, err := a.processECNodesRequests(ctx, remoteNodes, dataCount, parityCount)
	if err != nil {
		a.log.Debug(logs.GetUnableToGetAllPartsECObject, zap.Error(err))
	}
	return parts
}

// processECNodesRequests concurrently collects chunks of the EC object a.addr.
//
// Chunks listed in a.ecInfo.localChunks are read from local storage. For every
// node in nodes, the chunk list is obtained with tryGetChunkListFromNode and
// the listed chunks are fetched from that node. At most dataCount goroutines
// run at a time. Once dataCount distinct chunks have been collected, a goroutine
// returns errECPartsRetrieveCompleted. That cancels the errgroup context so the
// remaining goroutines stop early.
//
// On success the result has length dataCount+parityCount. Each collected chunk
// sits at its EC index, and slots that were not found are nil. An error is
// returned when the chunk counts are invalid, or when a goroutine failed with
// anything other than errECPartsRetrieveCompleted (e.g. the parent context was
// cancelled).
func (a *assemblerec) processECNodesRequests(ctx context.Context, nodes []placement.Node, dataCount, parityCount int) ([]*objectSDK.Object, error) {
	// errgroup's SetLimit(0) makes every Go call block forever, and a negative
	// total would break the result allocation, so reject such counts up front.
	if dataCount <= 0 || parityCount < 0 {
		return nil, fmt.Errorf("invalid EC chunk counts: data %d, parity %d", dataCount, parityCount)
	}
	partsCount := dataCount + parityCount

	foundChunks := make(map[uint32]*objectSDK.Object, partsCount)
	var foundChunksGuard sync.Mutex

	// acceptChunk reports whether a chunk is worth fetching. Its index must fit
	// into the result slice, and it must not have been collected already.
	// Chunk lists may come from remote nodes, so an out-of-range index must not
	// be trusted: it would panic when the result is filled and would also count
	// toward dataCount.
	acceptChunk := func(node string, ch objectSDK.ECChunk) bool {
		if uint64(ch.Index) >= uint64(partsCount) {
			a.log.Error(logs.GetUnableToHeadPartECObject, zap.String("node", node), zap.Uint32("part_index", ch.Index),
				zap.Error(fmt.Errorf("part index out of range [0, %d)", partsCount)))
			return false
		}
		foundChunksGuard.Lock()
		_, found := foundChunks[ch.Index]
		foundChunksGuard.Unlock()
		return !found
	}

	// storeChunk records a fetched chunk. It returns errECPartsRetrieveCompleted
	// once dataCount distinct chunks have been collected.
	storeChunk := func(idx uint32, object *objectSDK.Object) error {
		foundChunksGuard.Lock()
		foundChunks[idx] = object
		count := len(foundChunks)
		foundChunksGuard.Unlock()
		if count >= dataCount {
			return errECPartsRetrieveCompleted
		}
		return nil
	}

	eg, ctx := errgroup.WithContext(ctx)
	eg.SetLimit(dataCount)

	for _, ch := range a.ecInfo.localChunks {
		if !acceptChunk("local", ch) {
			continue
		}
		eg.Go(func() error {
			if err := ctx.Err(); err != nil {
				return err
			}
			object := a.tryGetChunkFromLocalStorage(ctx, ch)
			if object == nil {
				return nil
			}
			return storeChunk(ch.Index, object)
		})
	}

	for _, node := range nodes {
		var info client.NodeInfo
		client.NodeInfoFromNetmapElement(&info, node)
		eg.Go(func() error {
			if err := ctx.Err(); err != nil {
				return err
			}
			nodeKey := hex.EncodeToString(info.PublicKey())
			for _, ch := range a.tryGetChunkListFromNode(ctx, info) {
				// After cancellation (enough chunks found elsewhere, or the
				// caller gave up), further requests would only fail and log.
				if err := ctx.Err(); err != nil {
					return err
				}
				if !acceptChunk(nodeKey, ch) {
					continue
				}
				object := a.tryGetChunkFromRemoteStorage(ctx, info, ch)
				if object == nil {
					continue
				}
				if err := storeChunk(ch.Index, object); err != nil {
					return err
				}
			}
			return nil
		})
	}

	if err := eg.Wait(); err != nil && !errors.Is(err, errECPartsRetrieveCompleted) {
		return nil, err
	}

	// eg.Wait has returned, so no goroutine touches foundChunks any more.
	// Every stored index was validated by acceptChunk, so it is < partsCount.
	parts := make([]*objectSDK.Object, partsCount)
	for idx, chunk := range foundChunks {
		parts[idx] = chunk
	}
	return parts, nil
}

// tryGetChunkFromLocalStorage reads chunk ch from this node's storage. In head
// mode it does a HEAD, otherwise a full GET.
//
// Failures (an invalid chunk ID or a storage error) are logged and reported
// as a nil result. The caller simply treats the chunk as unavailable.
func (a *assemblerec) tryGetChunkFromLocalStorage(ctx context.Context, ch objectSDK.ECChunk) *objectSDK.Object {
	var partID oid.ID
	if err := partID.ReadFromV2(ch.ID); err != nil {
		a.log.Error(logs.GetUnableToHeadPartECObject, zap.String("node", "local"), zap.Uint32("part_index", ch.Index), zap.Error(fmt.Errorf("invalid object ID: %w", err)))
		return nil
	}

	var partAddr oid.Address
	partAddr.SetContainer(a.addr.Container())
	partAddr.SetObject(partID)

	if a.head {
		part, err := a.localStorage.Head(ctx, partAddr, false)
		if err != nil {
			a.log.Warn(logs.GetUnableToHeadPartECObject, zap.String("node", "local"), zap.Stringer("part_id", partID), zap.Error(err))
			return nil
		}
		return part
	}

	part, err := a.localStorage.Get(ctx, partAddr)
	if err != nil {
		a.log.Warn(logs.GetUnableToGetPartECObject, zap.String("node", "local"), zap.Stringer("part_id", partID), zap.Error(err))
		return nil
	}
	return part
}

// tryGetChunkListFromNode returns the EC chunks stored on node.
//
// Chunk lists already known from a.ecInfo.remoteChunks (keyed by the node's
// public key) are returned as is. Otherwise a raw HEAD of a.addr is sent to the
// node, and the chunk list is taken from the *objectSDK.ECInfoError it fails
// with. The code treats a successful raw HEAD as unexpected and logs it; any
// other error is logged as well. In both of those cases nil is returned.
func (a *assemblerec) tryGetChunkListFromNode(ctx context.Context, node client.NodeInfo) []objectSDK.ECChunk {
	if known, ok := a.ecInfo.remoteChunks[string(node.PublicKey())]; ok {
		return known
	}

	_, err := a.remoteStorage.headObjectFromNode(ctx, a.addr, node, true)
	if err == nil {
		a.log.Error(logs.GetUnexpectedECObject, zap.String("node", hex.EncodeToString(node.PublicKey())))
		return nil
	}

	var ecInfoErr *objectSDK.ECInfoError
	if !errors.As(err, &ecInfoErr) {
		a.log.Warn(logs.GetUnableToHeadPartsECObject, zap.String("node", hex.EncodeToString(node.PublicKey())), zap.Error(err))
		return nil
	}

	chunks := ecInfoErr.ECInfo().Chunks
	result := make([]objectSDK.ECChunk, len(chunks))
	for i := range chunks {
		result[i] = objectSDK.ECChunk(chunks[i])
	}
	return result
}

// tryGetChunkFromRemoteStorage requests chunk ch from the remote node. In head
// mode it does a HEAD, otherwise a full GET.
//
// Failures (an invalid chunk ID or a request error) are logged together with
// the node's hex-encoded public key and reported as a nil result.
func (a *assemblerec) tryGetChunkFromRemoteStorage(ctx context.Context, node client.NodeInfo, ch objectSDK.ECChunk) *objectSDK.Object {
	// nodeField is only built on error paths, so the key is hex-encoded lazily.
	nodeField := func() zap.Field {
		return zap.String("node", hex.EncodeToString(node.PublicKey()))
	}

	var partID oid.ID
	if err := partID.ReadFromV2(ch.ID); err != nil {
		a.log.Error(logs.GetUnableToHeadPartECObject, nodeField(), zap.Uint32("part_index", ch.Index), zap.Error(fmt.Errorf("invalid object ID: %w", err)))
		return nil
	}

	var partAddr oid.Address
	partAddr.SetContainer(a.addr.Container())
	partAddr.SetObject(partID)

	if a.head {
		part, err := a.remoteStorage.headObjectFromNode(ctx, partAddr, node, false)
		if err != nil {
			a.log.Warn(logs.GetUnableToHeadPartECObject, nodeField(), zap.Stringer("part_id", partID), zap.Error(err))
			return nil
		}
		return part
	}

	part, err := a.remoteStorage.getObjectFromNode(ctx, partAddr, node)
	if err != nil {
		a.log.Warn(logs.GetUnableToGetPartECObject, nodeField(), zap.Stringer("part_id", partID), zap.Error(err))
		return nil
	}
	return part
}