package writer

import (
	"context"
	"crypto/ecdsa"
	"encoding/hex"
	"errors"
	"fmt"
	"sync/atomic"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/client"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/object"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/policy"
	svcutil "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/util"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement"
	containerSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

var (
	// Compile-time check that *ECWriter implements transformer.ObjectWriter.
	_ transformer.ObjectWriter = (*ECWriter)(nil)

	// errUnsupportedECObject is returned by WriteObject for objects that
	// object.IsECSupported rejects; the caller is expected to handle them.
	errUnsupportedECObject = errors.New("object is not supported for erasure coding")
	// errFailedToSaveAllECParts is reported when at least one EC part could not
	// be stored on any node.
	errFailedToSaveAllECParts = errors.New("failed to save all EC parts")
)

// ECWriter is a transformer.ObjectWriter for containers whose placement policy
// uses erasure coding. Depending on where it runs, it either relays the request
// to a container node or stores the object (split into EC parts, or as an
// already-formed EC part) on the nodes selected by the placement traverser.
type ECWriter struct {
	// Config holds shared dependencies used here: FormatValidator, KeyStorage,
	// NetmapKeys, ClientConstructor, RemotePool, LocalPool, LocalStore and Logger.
	Config *Config
	// PlacementOpts are passed to placement.NewTraverser; an object-specific
	// placement.ForObject option is appended per call where needed.
	PlacementOpts []placement.Option
	// Container's placement policy supplies the EC data/parity part counts; it
	// is also handed to LocalTarget for local writes.
	Container containerSDK.Container
	// Key signs the EC parts produced by splitting a raw object, and signs
	// remote part writes when the current node is not a container node.
	Key *ecdsa.PrivateKey
	// CommonPrm carries the common request parameters (LocalOnly, tokens) and
	// is forwarded to remote writers.
	CommonPrm *svcutil.CommonPrm
	// Relay, when non-nil, is used to forward the request to a container node
	// if the current node is not one of them.
	Relay         func(context.Context, client.NodeInfo, client.MultiAddressClient) error

	// ObjectMeta caches the result of FormatValidator.ValidateContent;
	// ObjectMetaValid reports whether that cache is populated.
	ObjectMeta      object.ContentMeta
	ObjectMetaValid bool

	// remoteRequestSignKey is chosen by WriteObject on every call and is used to
	// sign part writes sent to remote nodes.
	remoteRequestSignKey *ecdsa.PrivateKey
}

// WriteObject stores obj according to the container's EC placement policy.
//
// When the current node is not a container node and Relay is set, the request
// is forwarded to a container node and nothing else happens here. Otherwise
// the object content is validated once (the result is cached in ObjectMeta),
// the key for signing remote part writes is selected, and obj is written
// either as a ready EC part (it carries an EC header) or split into EC parts.
// Objects rejected by object.IsECSupported yield errUnsupportedECObject.
func (e *ECWriter) WriteObject(ctx context.Context, obj *objectSDK.Object) error {
	relayed, isContainerNode, err := e.relayIfNotContainerNode(ctx, obj)
	switch {
	case err != nil:
		return err
	case relayed:
		return nil
	case !object.IsECSupported(obj):
		// must be resolved by caller
		return errUnsupportedECObject
	}

	if !e.ObjectMetaValid {
		e.ObjectMeta, err = e.Config.FormatValidator.ValidateContent(obj)
		if err != nil {
			return fmt.Errorf("(%T) could not validate payload content: %w", e, err)
		}
		e.ObjectMetaValid = true
	}

	// Remote part writes are signed with the writer's own key unless this node
	// belongs to the container.
	e.remoteRequestSignKey = e.Key
	if isContainerNode {
		// Request tokens are forgotten for the duration of the write and
		// restored on return; remote requests are signed with the key from
		// KeyStorage instead of e.Key.
		restoreTokens := e.CommonPrm.ForgetTokens()
		defer restoreTokens()
		if e.remoteRequestSignKey, err = e.Config.KeyStorage.GetKey(nil); err != nil {
			return err
		}
	}

	if obj.ECHeader() == nil {
		return e.writeRawObject(ctx, obj)
	}
	return e.writeECPart(ctx, obj)
}

// relayIfNotContainerNode forwards the request to a container node when the
// current node is not one and a Relay function is configured.
//
// It returns (relayed, isContainerNode, err). relayed is true only when the
// request was successfully handed off; isContainerNode is true only when the
// current node is in the container's placement.
func (e *ECWriter) relayIfNotContainerNode(ctx context.Context, obj *objectSDK.Object) (bool, bool, error) {
	isContainerNode, err := e.currentNodeIsContainerNode()
	switch {
	case err != nil:
		return false, false, err
	case isContainerNode:
		// object can be splitted or saved local
		return false, true, nil
	case e.Relay == nil:
		return false, false, nil
	}

	// EC parts are placed by their parent object, rotated by the part index.
	target := object.AddressOf(obj).Object()
	var partIndex uint32
	if ecHdr := obj.ECHeader(); ecHdr != nil {
		target, partIndex = ecHdr.Parent(), ecHdr.Index()
	}
	if err := e.relayToContainerNode(ctx, target, partIndex); err != nil {
		return false, false, err
	}
	return true, false, nil
}

// currentNodeIsContainerNode reports whether any node returned by a traverser
// built from PlacementOpts has a public key that NetmapKeys considers local.
func (e *ECWriter) currentNodeIsContainerNode() (bool, error) {
	traverser, err := placement.NewTraverser(e.PlacementOpts...)
	if err != nil {
		return false, err
	}
	for vector := traverser.Next(); len(vector) > 0; vector = traverser.Next() {
		for i := range vector {
			if e.Config.NetmapKeys.IsLocalKey(vector[i].PublicKey()) {
				return true, nil
			}
		}
	}
	return false, nil
}

// relayToContainerNode forwards the request to a container node via e.Relay.
//
// The traverser is built for objID (the parent ID when relaying an EC part).
// Within each placement vector, nodes are tried in order rotated by index. The
// first successful relay returns nil. A failure to build an SDK client or to
// submit the task to Config.RemotePool aborts immediately. If every node failed,
// the last relay error is returned wrapped in errIncompletePut.
//
// NOTE(review): if the traverser yields no nodes at all, lastErr stays nil and
// nil is returned, which the caller treats as a successful relay. Confirm that
// this cannot happen.
func (e *ECWriter) relayToContainerNode(ctx context.Context, objID oid.ID, index uint32) error {
	t, err := placement.NewTraverser(append(e.PlacementOpts, placement.ForObject(objID))...)
	if err != nil {
		return err
	}
	var lastErr error
	offset := int(index)
	for {
		nodes := t.Next()
		if len(nodes) == 0 {
			break
		}
		for idx := range nodes {
			node := nodes[(idx+offset)%len(nodes)]
			var info client.NodeInfo
			client.NodeInfoFromNetmapElement(&info, node)

			c, err := e.Config.ClientConstructor.Get(info)
			if err != nil {
				return fmt.Errorf("could not create SDK client %s: %w", info.AddressGroup(), err)
			}

			// The relay runs on the remote pool. err is written by the task and
			// read here only after completed is closed.
			completed := make(chan struct{})
			if poolErr := e.Config.RemotePool.Submit(func() {
				defer close(completed)
				err = e.Relay(ctx, info, c)
			}); poolErr != nil {
				close(completed)
				svcutil.LogWorkerPoolError(ctx, e.Config.Logger, "PUT", poolErr)
				return poolErr
			}
			<-completed

			if err == nil {
				return nil
			}
			// Log the relay error too; without it the warning does not say why
			// this node was skipped.
			e.Config.Logger.Warn(ctx, logs.ECFailedToSendToContainerNode,
				zap.Stringers("address_group", info.AddressGroup()), zap.Error(err))
			lastErr = err
		}
	}
	if lastErr == nil {
		return nil
	}
	return errIncompletePut{
		singleErr: lastErr,
	}
}

// writeECPart stores obj, which already carries an EC header, on nodes chosen
// for its parent object.
//
// With CommonPrm.LocalOnly set, the part is written locally only. Otherwise one
// writePart task is started per placement vector, and any task failure is
// reported as errIncompletePut.
func (e *ECWriter) writeECPart(ctx context.Context, obj *objectSDK.Object) error {
	if e.CommonPrm.LocalOnly() {
		return e.writePartLocal(ctx, obj)
	}

	ecHeader := obj.ECHeader()
	traverser, err := placement.NewTraverser(append(e.PlacementOpts, placement.ForObject(ecHeader.Parent()))...)
	if err != nil {
		return err
	}

	partIdx := int(ecHeader.Index())
	group, groupCtx := errgroup.WithContext(ctx)
	for {
		vector := traverser.Next()
		if len(vector) == 0 {
			break
		}
		// A single part does not compete with sibling parts, so each vector
		// starts with an empty "visited" set.
		group.Go(func() error {
			return e.writePart(groupCtx, obj, partIdx, vector, make([]atomic.Bool, len(vector)))
		})
		// NOTE(review): success is submitted as soon as the write is
		// scheduled, before its outcome is known. Confirm this is intended.
		traverser.SubmitSuccess()
	}

	if err = group.Wait(); err == nil {
		return nil
	}
	return errIncompletePut{singleErr: err}
}

// writeRawObject splits obj into EC parts, using the data/parity counts from
// the container placement policy, and stores every part.
//
// Placement vectors are processed in the order the traverser returns them.
// Within a vector, part i is first offered to node i%len(nodes) (see
// writePart). Parts not stored on one vector are retried on the next. The call
// succeeds only if every part ended up stored; otherwise errIncompletePut is
// returned.
func (e *ECWriter) writeRawObject(ctx context.Context, obj *objectSDK.Object) error {
	// now only single EC policy is supported
	c, err := erasurecode.NewConstructor(policy.ECDataCount(e.Container.PlacementPolicy()), policy.ECParityCount(e.Container.PlacementPolicy()))
	if err != nil {
		return err
	}
	parts, err := c.Split(obj, e.Key)
	if err != nil {
		return err
	}
	partsProcessed := make([]atomic.Bool, len(parts))
	// The second result of ID() is ignored. Presumably obj always has its ID
	// set by the time it reaches the EC writer — TODO confirm.
	objID, _ := obj.ID()
	t, err := placement.NewTraverser(append(e.PlacementOpts, placement.ForObject(objID))...)
	if err != nil {
		return err
	}

	for {
		nodes := t.Next()
		if len(nodes) == 0 {
			break
		}
		// Create the group only once there is work for it. errgroup.WithContext
		// derives a cancellable context that only Wait releases. Creating it
		// before the break above left that context registered with ctx.
		eg, egCtx := errgroup.WithContext(ctx)

		// Mark every node that is some part's primary target, so fallbacks in
		// writePart only claim nodes that no part is assigned to.
		visited := make([]atomic.Bool, len(nodes))
		for idx := range parts {
			visited[idx%len(nodes)].Store(true)
		}

		for idx := range parts {
			if partsProcessed[idx].Load() {
				continue // already stored via an earlier vector
			}
			eg.Go(func() error {
				err := e.writePart(egCtx, parts[idx], idx, nodes, visited)
				if err == nil {
					partsProcessed[idx].Store(true)
					t.SubmitSuccess()
				}
				return err
			})
		}
		// Only the last vector's error is kept. Parts that failed on earlier
		// vectors and were never stored are caught by the partsProcessed check.
		err = eg.Wait()
	}
	if err != nil {
		return errIncompletePut{
			singleErr: err,
		}
	}
	for idx := range partsProcessed {
		if !partsProcessed[idx].Load() {
			return errIncompletePut{
				singleErr: errFailedToSaveAllECParts,
			}
		}
	}
	return nil
}

// writePart stores the EC part obj on one of nodes, trying candidates in three
// rounds:
//  1. the node at partIdx % len(nodes);
//  2. nodes not yet claimed in visited, which is shared with the other parts
//     of the same vector, claiming each one atomically before use;
//  3. any remaining node this part has not attempted itself.
//
// Every failed attempt is logged. ctx cancellation is checked before each
// round-2 and round-3 attempt. An error is returned if no node accepted the part.
func (e *ECWriter) writePart(ctx context.Context, obj *objectSDK.Object, partIdx int, nodes []placement.Node, visited []atomic.Bool) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	// warnFailed logs one unsuccessful attempt to put obj on n.
	warnFailed := func(n placement.Node, err error) {
		e.Config.Logger.Warn(ctx, logs.ECFailedToSaveECPart, zap.Stringer("part_address", object.AddressOf(obj)),
			zap.Stringer("parent_address", obj.ECHeader().Parent()), zap.Int("part_index", partIdx),
			zap.String("node", hex.EncodeToString(n.PublicKey())), zap.Error(err))
	}
	// tryNode puts obj on nodes[i] and reports success, logging on failure.
	tryNode := func(i int) bool {
		if err := e.putECPartToNode(ctx, obj, nodes[i]); err != nil {
			warnFailed(nodes[i], err)
			return false
		}
		return true
	}

	count := len(nodes)
	home := partIdx % count

	// Round 1: the node assigned to this part index.
	if tryNode(home) {
		return nil
	}

	attempted := make([]bool, count)
	attempted[home] = true

	// Round 2: nodes no part of this vector has claimed yet.
	for shift := 1; shift < count; shift++ {
		if err := ctx.Err(); err != nil {
			return err
		}
		i := (partIdx + shift) % count
		if !visited[i].CompareAndSwap(false, true) {
			continue
		}
		if tryNode(i) {
			return nil
		}
		attempted[i] = true
	}

	// Round 3: every node this part has not tried on its own.
	for i := range nodes {
		if err := ctx.Err(); err != nil {
			return err
		}
		if !attempted[i] && tryNode(i) {
			return nil
		}
	}

	return fmt.Errorf("failed to save EC chunk %s to any node", object.AddressOf(obj))
}

// putECPartToNode writes obj to node. The write goes to local storage when
// NetmapKeys recognises node's public key as local; otherwise it is sent over
// the network.
func (e *ECWriter) putECPartToNode(ctx context.Context, obj *objectSDK.Object, node placement.Node) error {
	isLocal := e.Config.NetmapKeys.IsLocalKey(node.PublicKey())
	if !isLocal {
		return e.writePartRemote(ctx, obj, node)
	}
	return e.writePartLocal(ctx, obj)
}

// writePartLocal stores obj in Config.LocalStore through a LocalTarget. The
// write runs on Config.LocalPool, and this call blocks until it finishes. A
// pool submission error is returned as is.
func (e *ECWriter) writePartLocal(ctx context.Context, obj *objectSDK.Object) error {
	target := LocalTarget{
		Storage:   e.Config.LocalStore,
		Container: e.Container,
	}

	var writeErr error
	done := make(chan struct{})
	submitErr := e.Config.LocalPool.Submit(func() {
		defer close(done)
		writeErr = target.WriteObject(ctx, obj, e.ObjectMeta)
	})
	if submitErr != nil {
		close(done)
		return submitErr
	}
	<-done
	return writeErr
}

// writePartRemote sends obj to node through a remoteWriter, signed with
// remoteRequestSignKey. The write runs on Config.RemotePool, and this call
// blocks until it finishes. A pool submission error is returned as is.
func (e *ECWriter) writePartRemote(ctx context.Context, obj *objectSDK.Object, node placement.Node) error {
	var info client.NodeInfo
	client.NodeInfoFromNetmapElement(&info, node)

	target := remoteWriter{
		privateKey:        e.remoteRequestSignKey,
		clientConstructor: e.Config.ClientConstructor,
		commonPrm:         e.CommonPrm,
		nodeInfo:          info,
	}

	var writeErr error
	done := make(chan struct{})
	submitErr := e.Config.RemotePool.Submit(func() {
		defer close(done)
		writeErr = target.WriteObject(ctx, obj, e.ObjectMeta)
	})
	if submitErr != nil {
		close(done)
		return submitErr
	}
	<-done
	return writeErr
}