package keys

import (
	"bytes"
	"errors"
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
	"github.com/nspcc-dev/neo-go/pkg/encoding/base58"
	"golang.org/x/crypto/scrypt"
	"golang.org/x/text/unicode/norm"
)

// NEP-2 standard implementation for encrypting and decrypting private keys.

// NEP-2 specified parameters used for cryptography.
const (
	// n is the scrypt N parameter (2^14 = 16384).
	n = 1 << 14
	// r is the scrypt r (block size) parameter.
	r = 8
	// p is the scrypt p (parallelization) parameter.
	p = 8
	// keyLen is the number of bytes requested from scrypt; the encrypt and
	// decrypt routines split it into two 32-byte halves.
	keyLen = 64
	// nepFlag is the flag byte written immediately after nepHeader.
	nepFlag = 0xe0
)

// nepHeader is the two-byte prefix of a NEP-2 encoded key.
var nepHeader = []byte{0x01, 0x42}

// ScryptParams is a JSON-serializable container for scrypt KDF parameters.
type ScryptParams struct {
	N int `json:"n"`
	R int `json:"r"`
	P int `json:"p"`
}

// NEP2ScryptParams returns the scrypt parameters specified by NEP-2
// (N=16384, r=8, p=8).
func NEP2ScryptParams() ScryptParams {
	var params ScryptParams
	params.N, params.R, params.P = n, r, p
	return params
}

// NEP2Encrypt encrypts the PrivateKey using the given passphrase under the
// NEP-2 standard and returns the result Base58Check-encoded.
//
// The checksum of the key's address is used as the scrypt salt. The first
// half of the scrypt output is XORed with the raw private key bytes, and the
// second half is used as the AES key to encrypt that XOR result. The encoded
// payload is nepHeader | nepFlag | address hash | ciphertext.
func NEP2Encrypt(priv *PrivateKey, passphrase string) (string, error) {
	address := priv.Address()
	addrHash := hash.Checksum([]byte(address))

	// Normalize the passphrase according to the NFC standard.
	phraseNorm := norm.NFC.Bytes([]byte(passphrase))
	derivedKey, err := scrypt.Key(phraseNorm, addrHash, n, r, p, keyLen)
	if err != nil {
		return "", err
	}

	derivedKey1 := derivedKey[:32]
	derivedKey2 := derivedKey[32:]
	xr := xor(priv.Bytes(), derivedKey1)

	encrypted, err := aesEncrypt(xr, derivedKey2)
	if err != nil {
		return "", err
	}

	var buf bytes.Buffer
	buf.Write(nepHeader)
	buf.WriteByte(nepFlag)
	buf.Write(addrHash)
	buf.Write(encrypted)

	// Expected layout: 2 header bytes + 1 flag byte + 4 address hash bytes +
	// 32 ciphertext bytes, matching what NEP2Decrypt slices apart.
	if buf.Len() != 39 {
		return "", fmt.Errorf("invalid buffer length: expecting 39 bytes got %d", buf.Len())
	}

	return base58.CheckEncode(buf.Bytes()), nil
}

// NEP2Decrypt decrypts a Base58Check-encoded NEP-2 key using the given
// passphrase and returns the recovered private key. It fails with
// "password mismatch" when the recovered key's address hash does not match
// the one stored in the encoded payload.
func NEP2Decrypt(key, passphrase string) (*PrivateKey, error) {
	raw, err := base58.CheckDecode(key)
	if err != nil {
		return nil, err
	}
	if err = validateNEP2Format(raw); err != nil {
		return nil, err
	}

	// After validation the payload is: header(2) | flag(1) | address hash(4) | ciphertext.
	salt, ciphertext := raw[3:7], raw[7:]

	// The passphrase is NFC-normalized before key derivation.
	dk, err := scrypt.Key(norm.NFC.Bytes([]byte(passphrase)), salt, n, r, p, keyLen)
	if err != nil {
		return nil, err
	}
	xorKey, aesKey := dk[:32], dk[32:]

	plain, err := aesDecrypt(ciphertext, aesKey)
	if err != nil {
		return nil, err
	}

	priv, err := NewPrivateKeyFromBytes(xor(plain, xorKey))
	if err != nil {
		return nil, err
	}

	// A mismatching address hash is reported as a wrong passphrase.
	if !compareAddressHash(priv, salt) {
		return nil, errors.New("password mismatch")
	}
	return priv, nil
}

// compareAddressHash reports whether inhash equals the checksum of the
// address derived from priv.
func compareAddressHash(priv *PrivateKey, inhash []byte) bool {
	return bytes.Equal(hash.Checksum([]byte(priv.Address())), inhash)
}

// validateNEP2Format checks that b is exactly 39 bytes long and begins with
// the bytes 0x01 0x42 0xe0. It returns an error describing the first failed
// check, or nil if b is well-formed.
func validateNEP2Format(b []byte) error {
	switch {
	case len(b) != 39:
		return fmt.Errorf("invalid length: expecting 39 got %d", len(b))
	case b[0] != 0x01:
		return fmt.Errorf("invalid byte sequence: expecting 0x01 got 0x%02x", b[0])
	case b[1] != 0x42:
		return fmt.Errorf("invalid byte sequence: expecting 0x42 got 0x%02x", b[1])
	case b[2] != 0xe0:
		return fmt.Errorf("invalid byte sequence: expecting 0xe0 got 0x%02x", b[2])
	default:
		return nil
	}
}

// xor returns a newly allocated slice holding the byte-wise XOR of a and b.
// Neither input is modified. It panics if a and b differ in length.
func xor(a, b []byte) []byte {
	if len(a) != len(b) {
		panic("cannot XOR non equal length arrays")
	}
	out := make([]byte, len(a))
	for i, x := range a {
		out[i] = x ^ b[i]
	}
	return out
}