package mfa import ( "bytes" "crypto/cipher" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/sha256" "errors" "fmt" "io" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/pquerna/otp" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" "google.golang.org/protobuf/proto" ) const ( secretLength = 32 saltLength = 16 ) // PackMFABox encrypts OTP Key in a MFABox. Holders of unlocker private keys // can unpack this object and decrypt OTP Key. func PackMFABox(secret *otp.Key, unlockerKeys []*keys.PublicKey) (*MFABox, error) { if len(unlockerKeys) == 0 { return nil, errors.New("no unlocker keys provided") } // First step: generate encryption key and encrypt secret data with it. secretURL := secret.URL() // prepare MFA secret for encryption data, err := proto.Marshal(&Secrets{MFAURL: &secretURL}) if err != nil { return nil, fmt.Errorf("marshal secrets: %w", err) } // generate symmetric key to encrypt MFA secret secretEncryptionKey, err := generateRandomBytes(secretLength) if err != nil { return nil, fmt.Errorf("generate secrets encryption key: %w", err) } // encrypt MFA secret with ChaCha20-Poly1305 AEAD algorithm encryptedSecrets, hkdfsalt, err := encryptData(data, secretEncryptionKey) if err != nil { return nil, fmt.Errorf("encrypt secrets: %w", err) } // Second step: for each unlocker, encrypt secret encryption key, so // each unlocker could decrypt encryption key and then decrypt MFA secret with it. // generate ECDSA P-256 curve private key to derive unique encryption // key for every unlocker with ECDH algorithm. ecdhKey, err := keys.NewPrivateKey() if err != nil { return nil, fmt.Errorf("create private key for ECDH: %w", err) } unlockers := make([]*Unlocker, len(unlockerKeys)) for i := range unlockerKeys { unlockers[i], err = packUnlocker(secretEncryptionKey, ecdhKey, unlockerKeys[i]) if err != nil { return nil, fmt.Errorf("create unlocker: %w", err) } } return &MFABox{ Unlockers: unlockers, ECDHPublicKey: ecdhKey.PublicKey().Bytes(), EncryptedSecrets: encryptedSecrets, Salt: hkdfsalt, }, nil } // UnpackMFABox decrypts OTP key using unlocker key. func UnpackMFABox(box *MFABox, unlockerKey *keys.PrivateKey) (*otp.Key, error) { unlockerPublicKey := unlockerKey.PublicKey().Bytes() ecdhKey, err := keys.NewPublicKeyFromBytes(box.GetECDHPublicKey(), elliptic.P256()) if err != nil { return nil, fmt.Errorf("parse ECDH key: %w", err) } // First step: find unlocker message for unlocker key var suitableUnlocker *Unlocker for _, unlocker := range box.GetUnlockers() { if bytes.Equal(unlockerPublicKey, unlocker.GetPublicKey()) { suitableUnlocker = unlocker break } } if suitableUnlocker == nil { return nil, fmt.Errorf("no unlocker for %x", unlockerPublicKey) } // Second step: decrypt encryption key of MFA secret secretEncryptionKey, err := unpackUnlocker(suitableUnlocker, ecdhKey, unlockerKey) if err != nil { return nil, fmt.Errorf("unpack unlocker: %w", err) } // Third step: decrypt MFA secret data, err := decryptData(box.GetEncryptedSecrets(), secretEncryptionKey, box.GetSalt()) if err != nil { return nil, fmt.Errorf("decrypt secrets: %w", err) } secrets := new(Secrets) if err = proto.Unmarshal(data, secrets); err != nil { return nil, fmt.Errorf("unmarshal secrets: %w", err) } key, err := otp.NewKeyFromURL(secrets.GetMFAURL()) if err != nil { return nil, fmt.Errorf("parse OTP key: %w", err) } return key, nil } func packUnlocker(data []byte, ecdhKey *keys.PrivateKey, unlockerKey *keys.PublicKey) (*Unlocker, error) { // derive unique encryption key for unlocker with ECDH algorithm uniqueUnlockerKey, err := deriveECDH(ecdhKey, unlockerKey) if err != nil { return nil, fmt.Errorf("generate ECDH: %w", err) } // encrypt data based on unique encryption key encryptedData, salt, err := encryptData(data, uniqueUnlockerKey) return &Unlocker{ PublicKey: unlockerKey.Bytes(), EncryptedSecretsKey: encryptedData, Salt: salt, }, err } func unpackUnlocker(unlocker *Unlocker, ecdhKey *keys.PublicKey, unlockerKey *keys.PrivateKey) ([]byte, error) { // derive unique encryption key for unlocker with ECDH algorithm uniqueUnlockerKey, err := deriveECDH(unlockerKey, ecdhKey) if err != nil { return nil, fmt.Errorf("generate ECDH: %w", err) } return decryptData(unlocker.GetEncryptedSecretsKey(), uniqueUnlockerKey, unlocker.GetSalt()) } func encryptData(data, encryptionKey []byte) (encryptedData []byte, salt []byte, err error) { // generate salt for HKDF key derive function salt, err = generateRandomBytes(saltLength) if err != nil { return nil, nil, fmt.Errorf("generate HKDF salt: %w", err) } // get ChaCha20-Poly1305 AEAD cipher based on // a key derived from encryptionKey and random salt enc, err := getCipher(encryptionKey, salt) if err != nil { return nil, nil, fmt.Errorf("prepare AEAD: %w", err) } // generate random nonce to encrypt data nonce := make([]byte, enc.NonceSize()) _, err = rand.Read(nonce) if err != nil { return nil, nil, fmt.Errorf("generate random nonce: %w", err) } return enc.Seal(nonce, nonce, data, nil), salt, nil } func decryptData(encryptedData, encryptionKey, salt []byte) (data []byte, err error) { // get ChaCha20-Poly1305 AEAD cipher based on // a key derived from encryptionKey and random salt dec, err := getCipher(encryptionKey, salt) if err != nil { return nil, fmt.Errorf("prepare AEAD: %w", err) } ld, ns := len(encryptedData), dec.NonceSize() if ld < ns { return nil, fmt.Errorf("data size %d, should be greater than nonce size %d", ld, ns) } nonce, cypher := encryptedData[:dec.NonceSize()], encryptedData[dec.NonceSize():] return dec.Open(nil, nonce, cypher, nil) } func getCipher(secret, salt []byte) (cipher.AEAD, error) { key, err := deriveHKDF(secret, salt) if err != nil { return nil, fmt.Errorf("derive key: %w", err) } return chacha20poly1305.NewX(key) } func deriveECDH(prv *keys.PrivateKey, pub *keys.PublicKey) ([]byte, error) { prvECDH, err := prv.ECDH() if err != nil { return nil, fmt.Errorf("invalid ECDH private key: %w", err) } pubECDH, err := (*ecdsa.PublicKey)(pub).ECDH() if err != nil { return nil, fmt.Errorf("invalid ECDH public key: %w", err) } return prvECDH.ECDH(pubECDH) } func deriveHKDF(secret, salt []byte) ([]byte, error) { hash := sha256.New kdf := hkdf.New(hash, secret, salt, nil) key := make([]byte, 32) _, err := io.ReadFull(kdf, key) return key, err } func generateRandomBytes(length int) ([]byte, error) { b := make([]byte, length) _, err := rand.Read(b) return b, err }