2025-03-13 11:57:16 +03:00
|
|
|
package mfa
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"crypto/cipher"
|
|
|
|
"crypto/ecdsa"
|
|
|
|
"crypto/elliptic"
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/sha256"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
|
|
|
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
|
|
|
"github.com/pquerna/otp"
|
|
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
|
|
"golang.org/x/crypto/hkdf"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Sizes, in bytes, of random values generated by this package.
const (
	// secretLength is the size of the random symmetric key that encrypts the
	// marshaled Secrets message.
	secretLength = 32

	// saltLength is the size of the random HKDF salt generated for every
	// encryptData call.
	saltLength = 16
)
|
|
|
|
|
|
|
|
// PackMFABox encrypts an OTP key into an MFABox. Holders of the private keys
// matching any of unlockerKeys can open the box with UnpackMFABox and recover
// the OTP key.
//
// Encryption happens in two layers:
//  1. The OTP key URL is wrapped in a Secrets message and encrypted with a
//     freshly generated random key of secretLength bytes.
//  2. A fresh P-256 private key is generated for this box. For every unlocker
//     a unique key is derived via ECDH between that private key and the
//     unlocker's public key, and the layer-1 key is encrypted with it.
//
// The box stores the corresponding ECDH public key so unlockers can repeat
// the derivation from their side.
//
// An error is returned when secret is nil, unlockerKeys is empty or contains
// a nil key, or any marshaling/encryption step fails.
func PackMFABox(secret *otp.Key, unlockerKeys []*keys.PublicKey) (*MFABox, error) {
	if secret == nil {
		return nil, errors.New("no secret provided")
	}
	if len(unlockerKeys) == 0 {
		return nil, errors.New("no unlocker keys provided")
	}
	// Reject nil keys up front: they would otherwise cause a nil-pointer
	// panic inside the ECDH derivation.
	for i, key := range unlockerKeys {
		if key == nil {
			return nil, fmt.Errorf("unlocker key %d is nil", i)
		}
	}

	// First step: generate encryption key and encrypt secret data with it.
	secretURL := secret.URL()

	// prepare MFA secret for encryption
	data, err := proto.Marshal(&Secrets{MFAURL: &secretURL})
	if err != nil {
		return nil, fmt.Errorf("marshal secrets: %w", err)
	}

	// generate symmetric key to encrypt MFA secret
	secretEncryptionKey, err := generateRandomBytes(secretLength)
	if err != nil {
		return nil, fmt.Errorf("generate secrets encryption key: %w", err)
	}

	// encrypt MFA secret with the AEAD built by encryptData
	encryptedSecrets, hkdfSalt, err := encryptData(data, secretEncryptionKey)
	if err != nil {
		return nil, fmt.Errorf("encrypt secrets: %w", err)
	}

	// Second step: for each unlocker, encrypt secret encryption key, so
	// each unlocker could decrypt encryption key and then decrypt MFA secret with it.

	// generate ECDSA P-256 curve private key to derive unique encryption
	// key for every unlocker with ECDH algorithm.
	ecdhKey, err := keys.NewPrivateKey()
	if err != nil {
		return nil, fmt.Errorf("create private key for ECDH: %w", err)
	}

	unlockers := make([]*Unlocker, len(unlockerKeys))
	for i, key := range unlockerKeys {
		unlockers[i], err = packUnlocker(secretEncryptionKey, ecdhKey, key)
		if err != nil {
			return nil, fmt.Errorf("create unlocker: %w", err)
		}
	}

	return &MFABox{
		Unlockers:        unlockers,
		ECDHPublicKey:    ecdhKey.PublicKey().Bytes(),
		EncryptedSecrets: encryptedSecrets,
		Salt:             hkdfSalt,
	}, nil
}
|
|
|
|
|
|
|
|
// UnpackMFABox decrypts the OTP key stored in box using unlockerKey.
//
// It locates the Unlocker whose public key bytes equal
// unlockerKey.PublicKey().Bytes() and re-derives that unlocker's unique key
// via ECDH with the box's ECDH public key (parsed as a P-256 point). It then
// recovers the secrets encryption key from the unlocker. Finally it decrypts
// the Secrets message and parses its MFA URL into an OTP key.
//
// An error is returned when unlockerKey is nil, the box's ECDH key cannot be
// parsed, the box has no unlocker for this key, any decryption or unmarshaling
// step fails, or the decrypted secrets carry no MFA URL.
func UnpackMFABox(box *MFABox, unlockerKey *keys.PrivateKey) (*otp.Key, error) {
	if unlockerKey == nil {
		return nil, errors.New("no unlocker key provided")
	}
	unlockerPublicKey := unlockerKey.PublicKey().Bytes()

	ecdhKey, err := keys.NewPublicKeyFromBytes(box.GetECDHPublicKey(), elliptic.P256())
	if err != nil {
		return nil, fmt.Errorf("parse ECDH key: %w", err)
	}

	// First step: find unlocker message for unlocker key
	var suitableUnlocker *Unlocker
	for _, unlocker := range box.GetUnlockers() {
		if bytes.Equal(unlockerPublicKey, unlocker.GetPublicKey()) {
			suitableUnlocker = unlocker
			break
		}
	}
	if suitableUnlocker == nil {
		return nil, fmt.Errorf("no unlocker for %x", unlockerPublicKey)
	}

	// Second step: decrypt encryption key of MFA secret
	secretEncryptionKey, err := unpackUnlocker(suitableUnlocker, ecdhKey, unlockerKey)
	if err != nil {
		return nil, fmt.Errorf("unpack unlocker: %w", err)
	}

	// Third step: decrypt MFA secret
	data, err := decryptData(box.GetEncryptedSecrets(), secretEncryptionKey, box.GetSalt())
	if err != nil {
		return nil, fmt.Errorf("decrypt secrets: %w", err)
	}

	secrets := new(Secrets)
	if err = proto.Unmarshal(data, secrets); err != nil {
		return nil, fmt.Errorf("unmarshal secrets: %w", err)
	}

	// An absent MFA URL reads back as "". otp.NewKeyFromURL only checks that
	// the string parses as a URL (and "" does), so it would return a key with
	// no secret. Reject that case explicitly.
	mfaURL := secrets.GetMFAURL()
	if mfaURL == "" {
		return nil, errors.New("no MFA URL in secrets")
	}

	key, err := otp.NewKeyFromURL(mfaURL)
	if err != nil {
		return nil, fmt.Errorf("parse OTP key: %w", err)
	}

	return key, nil
}
|
|
|
|
|
|
|
|
// packUnlocker encrypts data (the secrets encryption key) for the holder of
// the private key matching unlockerKey.
//
// A unique key is derived with ECDH between ecdhKey and unlockerKey. data is
// then encrypted under it via encryptData. The resulting Unlocker records
// unlockerKey.Bytes(), which lets UnpackMFABox find it, plus the ciphertext and
// the salt needed to rebuild the cipher.
//
// On any failure it returns a nil Unlocker and a wrapped error.
func packUnlocker(data []byte, ecdhKey *keys.PrivateKey, unlockerKey *keys.PublicKey) (*Unlocker, error) {
	// derive unique encryption key for unlocker with ECDH algorithm
	uniqueUnlockerKey, err := deriveECDH(ecdhKey, unlockerKey)
	if err != nil {
		return nil, fmt.Errorf("generate ECDH: %w", err)
	}

	// encrypt data based on unique encryption key; never hand back a
	// half-built Unlocker when encryption fails
	encryptedData, salt, err := encryptData(data, uniqueUnlockerKey)
	if err != nil {
		return nil, fmt.Errorf("encrypt data: %w", err)
	}

	return &Unlocker{
		PublicKey:           unlockerKey.Bytes(),
		EncryptedSecretsKey: encryptedData,
		Salt:                salt,
	}, nil
}
|
|
|
|
|
|
|
|
// unpackUnlocker returns the secrets encryption key sealed inside unlocker.
//
// It mirrors packUnlocker from the other side. The same unique key is
// obtained by running ECDH between the unlocker's private key and the box's
// ECDH public key. That key, together with the unlocker's stored salt,
// decrypts the unlocker payload. Errors from decryptData are returned as-is.
func unpackUnlocker(unlocker *Unlocker, ecdhKey *keys.PublicKey, unlockerKey *keys.PrivateKey) ([]byte, error) {
	sharedKey, ecdhErr := deriveECDH(unlockerKey, ecdhKey)
	if ecdhErr != nil {
		return nil, fmt.Errorf("generate ECDH: %w", ecdhErr)
	}

	sealed, salt := unlocker.GetEncryptedSecretsKey(), unlocker.GetSalt()
	return decryptData(sealed, sharedKey, salt)
}
|
|
|
|
|
|
|
|
// encryptData encrypts data under a key derived from encryptionKey.
//
// A fresh random salt of saltLength bytes is generated and passed, together
// with encryptionKey, to getCipher, which derives the actual cipher key via
// HKDF. A fresh random nonce is then generated for the AEAD.
//
// The returned ciphertext is laid out as nonce || sealed data (Seal output,
// which includes the authentication tag). The salt is returned separately and
// must be supplied to decryptData.
func encryptData(data, encryptionKey []byte) (encryptedData []byte, salt []byte, err error) {
	hkdfSalt, err := generateRandomBytes(saltLength)
	if err != nil {
		return nil, nil, fmt.Errorf("generate HKDF salt: %w", err)
	}

	aead, err := getCipher(encryptionKey, hkdfSalt)
	if err != nil {
		return nil, nil, fmt.Errorf("prepare AEAD: %w", err)
	}

	nonce, err := generateRandomBytes(aead.NonceSize())
	if err != nil {
		return nil, nil, fmt.Errorf("generate random nonce: %w", err)
	}

	// Seal appends to its first argument, so using nonce as dst prefixes the
	// output with the nonce that decryptData later splits off.
	sealed := aead.Seal(nonce, nonce, data, nil)
	return sealed, hkdfSalt, nil
}
|
|
|
|
|
|
|
|
// decryptData reverses encryptData.
//
// It rebuilds the AEAD from encryptionKey and salt via getCipher. It then
// splits encryptedData into its leading nonce and the sealed remainder, and
// opens it. Authentication failures are reported by the AEAD's Open.
//
// encryptedData must be at least nonce size + tag overhead bytes long, since
// even an empty plaintext seals to exactly that many bytes.
func decryptData(encryptedData, encryptionKey, salt []byte) (data []byte, err error) {
	// get AEAD cipher based on a key derived from encryptionKey and salt
	dec, err := getCipher(encryptionKey, salt)
	if err != nil {
		return nil, fmt.Errorf("prepare AEAD: %w", err)
	}

	ld, ns, overhead := len(encryptedData), dec.NonceSize(), dec.Overhead()
	if minSize := ns + overhead; ld < minSize {
		return nil, fmt.Errorf("data size %d, should be at least %d (nonce size %d + tag size %d)",
			ld, minSize, ns, overhead)
	}

	nonce, sealed := encryptedData[:ns], encryptedData[ns:]

	return dec.Open(nil, nonce, sealed, nil)
}
|
|
|
|
|
|
|
|
// getCipher returns an XChaCha20-Poly1305 AEAD (chacha20poly1305.NewX).
// Its key is derived from secret and salt with deriveHKDF.
// Errors from key derivation are wrapped; errors from NewX are returned as-is.
func getCipher(secret, salt []byte) (cipher.AEAD, error) {
	derivedKey, kdfErr := deriveHKDF(secret, salt)
	if kdfErr != nil {
		return nil, fmt.Errorf("derive key: %w", kdfErr)
	}

	aead, err := chacha20poly1305.NewX(derivedKey)
	return aead, err
}
|
|
|
|
|
|
|
|
// deriveECDH computes the raw ECDH shared secret between prv and pub.
//
// Both keys are first converted to their crypto/ecdh form. The public key is
// converted through its *ecdsa.PublicKey representation. A conversion failure
// is reported as an invalid private/public key. The result is whatever
// (*ecdh.PrivateKey).ECDH returns, unchanged.
func deriveECDH(prv *keys.PrivateKey, pub *keys.PublicKey) ([]byte, error) {
	local, err := prv.ECDH()
	if err != nil {
		return nil, fmt.Errorf("invalid ECDH private key: %w", err)
	}

	ecdsaPub := (*ecdsa.PublicKey)(pub)
	remote, err := ecdsaPub.ECDH()
	if err != nil {
		return nil, fmt.Errorf("invalid ECDH public key: %w", err)
	}

	shared, err := local.ECDH(remote)
	return shared, err
}
|
|
|
|
|
|
|
|
func deriveHKDF(secret, salt []byte) ([]byte, error) {
|
|
|
|
hash := sha256.New
|
|
|
|
kdf := hkdf.New(hash, secret, salt, nil)
|
|
|
|
key := make([]byte, 32)
|
|
|
|
_, err := io.ReadFull(kdf, key)
|
|
|
|
return key, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// generateRandomBytes returns a new slice of length bytes filled from
// crypto/rand. If rand.Read fails, its error is returned unwrapped alongside
// the allocated slice.
func generateRandomBytes(length int) ([]byte, error) {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return buf, err
	}
	return buf, nil
}
|