package frostfs

import (
	"crypto/elliptic"
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client/frostfsid"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"go.uber.org/zap"
)

// bindCommon is the part of a bind/unbind key event that the processor
// needs in order to validate and approve it.
type bindCommon interface {
	// User returns the raw bytes that the processor decodes as a
	// big-endian 160-bit script hash of the key owner.
	User() []byte
	// Keys returns the raw keys of the event; each one is parsed by the
	// processor as a public key on the P-256 curve.
	Keys() [][]byte
	// TxHash returns the hash that the processor passes along with the
	// approval request.
	TxHash() util.Uint256
}

// processBind handles a bind (bind == true) or unbind (bind == false) key
// event.
//
// When the node is not in the alphabet state the event is ignored and true is
// returned. Otherwise the event is validated with checkBindCommon; a
// validation failure is logged and yields false. A valid event is forwarded to
// approveBindCommon and the result reports whether the approval succeeded.
func (np *Processor) processBind(e bindCommon, bind bool) bool {
	if !np.alphabetState.IsAlphabet() {
		np.log.Info(logs.FrostFSNonAlphabetModeIgnoreBind)
		return true
	}

	bindCtx := &bindCommonContext{bindCommon: e, bind: bind}

	if checkErr := np.checkBindCommon(bindCtx); checkErr != nil {
		np.log.Error(logs.FrostFSInvalidManageKeyEvent,
			zap.Bool("bind", bindCtx.bind),
			zap.String("error", checkErr.Error()),
		)

		return false
	}

	// approveBindCommon logs its own failures, so only the outcome matters here.
	if approveErr := np.approveBindCommon(bindCtx); approveErr != nil {
		return false
	}

	return true
}

// bindCommonContext carries a bind/unbind event together with the data
// derived from it while the event is being processed.
type bindCommonContext struct {
	// bindCommon is the event being processed.
	bindCommon

	// bind selects the operation: true adds the keys, false removes them.
	bind bool

	// scriptHash is the owner script hash decoded from User(); it is
	// populated by checkBindCommon.
	scriptHash util.Uint160
}

// checkBindCommon validates the event held by e.
//
// It decodes e.User() as a big-endian 160-bit script hash and stores the
// result in e.scriptHash, then checks that every entry of e.Keys() parses as a
// public key on the P-256 curve. The first decoding or parsing error is
// returned unchanged; nil means the event is well-formed.
func (np *Processor) checkBindCommon(e *bindCommonContext) error {
	ownerHash, decodeErr := util.Uint160DecodeBytesBE(e.User())
	// The decoded value is stored regardless of the outcome of decoding.
	e.scriptHash = ownerHash
	if decodeErr != nil {
		return decodeErr
	}

	p256 := elliptic.P256()

	rawKeys := e.Keys()
	for i := range rawKeys {
		if _, parseErr := keys.NewPublicKeyFromBytes(rawKeys[i], p256); parseErr != nil {
			return parseErr
		}
	}

	return nil
}

// approveBindCommon submits the bind or unbind request described by e through
// the FrostFS ID client.
//
// The owner is derived from e.User(), decoded as a big-endian 160-bit script
// hash and converted to wallet bytes. Depending on e.bind, the keys are either
// added (AddKeys) or removed (RemoveKeys). Any failure is logged and returned
// to the caller; nil is returned on success.
func (np *Processor) approveBindCommon(e *bindCommonContext) error {
	// Derive the owner's wallet address from the raw script hash.
	ownerHash, err := util.Uint160DecodeBytesBE(e.User())
	if err != nil {
		np.log.Error(logs.FrostFSCouldNotDecodeScriptHashFromBytes,
			zap.String("error", err.Error()),
		)

		return err
	}

	var ownerID user.ID
	ownerID.SetScriptHash(ownerHash)

	var prm frostfsid.CommonBindPrm
	prm.SetOwnerID(ownerID.WalletBytes())
	prm.SetKeys(e.Keys())
	prm.SetHash(e.bindCommon.TxHash())

	// action names the operation for the failure log message.
	var action string

	switch {
	case e.bind:
		action = "bind"
		err = np.frostfsIDClient.AddKeys(prm)
	default:
		action = "unbind"
		err = np.frostfsIDClient.RemoveKeys(prm)
	}

	if err == nil {
		return nil
	}

	np.log.Error(fmt.Sprintf("could not approve %s", action),
		zap.String("error", err.Error()))

	return err
}