diff --git a/cli/wallet/wallet.go b/cli/wallet/wallet.go index e0cecb402..503508503 100644 --- a/cli/wallet/wallet.go +++ b/cli/wallet/wallet.go @@ -40,21 +40,48 @@ func NewCommands() []cli.Command { }, }, { - Name: "open", - Usage: "open a existing NEO wallet", - Action: openWallet, + Name: "dump", + Usage: "check and dump an existing NEO wallet", + Action: dumpWallet, Flags: []cli.Flag{ cli.StringFlag{ Name: "path, p", Usage: "Target location of the wallet file.", }, + cli.BoolFlag{ + Name: "decrypt, d", + Usage: "Decrypt encrypted keys.", + }, }, }, }, }} } -func openWallet(ctx *cli.Context) error { +func dumpWallet(ctx *cli.Context) error { + path := ctx.String("path") + if len(path) == 0 { + return cli.NewExitError(errNoPath, 1) + } + wall, err := wallet.NewWalletFromFile(path) + if err != nil { + return cli.NewExitError(err, 1) + } + if ctx.Bool("decrypt") { + fmt.Print("Wallet password: ") + pass, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + return cli.NewExitError(err, 1) + } + for i := range wall.Accounts { + // Just testing the decryption here. + err := wall.Accounts[i].Decrypt(string(pass)) + if err != nil { + return cli.NewExitError(err, 1) + } + } + } + fmtPrintWallet(wall) return nil } @@ -77,7 +104,7 @@ func createWallet(ctx *cli.Context) error { } } - dumpWallet(wall) + fmtPrintWallet(wall) fmt.Printf("wallet successfully created, file location is %s\n", wall.Path()) return nil } @@ -110,7 +137,7 @@ func createAccount(ctx *cli.Context, wall *wallet.Wallet) error { return wall.CreateAccount(name, phrase) } -func dumpWallet(wall *wallet.Wallet) { +func fmtPrintWallet(wall *wallet.Wallet) { b, _ := wall.JSON() fmt.Println("") fmt.Println(string(b)) diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 8faf2c080..5f8999644 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -14,7 +14,6 @@ import ( "github.com/CityOfZion/neo-go/pkg/smartcontract" "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm/opcode" - "github.com/CityOfZion/neo-go/pkg/wallet" "github.com/nspcc-dev/dbft" "github.com/nspcc-dev/dbft/block" "github.com/nspcc-dev/dbft/crypto" @@ -178,16 +177,12 @@ func (s *service) validatePayload(p *Payload) bool { } func getKeyPair(cfg *config.WalletConfig) (crypto.PrivateKey, crypto.PublicKey) { - acc, err := wallet.DecryptAccount(cfg.Path, cfg.Password) + // TODO: replace with wallet opening from the given path (#588) + key, err := keys.NEP2Decrypt(cfg.Path, cfg.Password) if err != nil { return nil, nil } - key := acc.PrivateKey() - if key == nil { - return nil, nil - } - return &privateKey{PrivateKey: key}, &publicKey{PublicKey: key.PublicKey()} } diff --git a/pkg/crypto/keys/nep2.go b/pkg/crypto/keys/nep2.go index abe478490..1f49298d2 100644 --- a/pkg/crypto/keys/nep2.go +++ b/pkg/crypto/keys/nep2.go @@ -77,13 +77,13 @@ func NEP2Encrypt(priv *PrivateKey, passphrase string) (s string, err error) { // NEP2Decrypt decrypts an encrypted key using a given passphrase // under the NEP-2 standard. -func NEP2Decrypt(key, passphrase string) (s string, err error) { +func NEP2Decrypt(key, passphrase string) (*PrivateKey, error) { b, err := base58.CheckDecode(key) if err != nil { - return s, nil + return nil, err } if err := validateNEP2Format(b); err != nil { - return s, err + return nil, err } addrHash := b[3:7] @@ -91,7 +91,7 @@ func NEP2Decrypt(key, passphrase string) (s string, err error) { phraseNorm := norm.NFC.Bytes([]byte(passphrase)) derivedKey, err := scrypt.Key(phraseNorm, addrHash, n, r, p, keyLen) if err != nil { - return s, err + return nil, err } derivedKey1 := derivedKey[:32] @@ -100,7 +100,7 @@ func NEP2Decrypt(key, passphrase string) (s string, err error) { decrypted, err := aesDecrypt(encryptedBytes, derivedKey2) if err != nil { - return s, err + return nil, err } privBytes := xor(decrypted, derivedKey1) @@ -108,14 +108,14 @@ func NEP2Decrypt(key, passphrase string) (s string, err error) { // Rebuild the private key. privKey, err := NewPrivateKeyFromBytes(privBytes) if err != nil { - return s, err + return nil, err } if !compareAddressHash(privKey, addrHash) { - return s, errors.New("password mismatch") + return nil, errors.New("password mismatch") } - return privKey.WIF(), nil + return privKey, nil } func compareAddressHash(priv *PrivateKey, inhash []byte) bool { diff --git a/pkg/crypto/keys/nep2_test.go b/pkg/crypto/keys/nep2_test.go index 4c5e86c23..a12c21447 100644 --- a/pkg/crypto/keys/nep2_test.go +++ b/pkg/crypto/keys/nep2_test.go @@ -27,18 +27,13 @@ func TestNEP2Encrypt(t *testing.T) { func TestNEP2Decrypt(t *testing.T) { for _, testCase := range keytestcases.Arr { - - privKeyString, err := NEP2Decrypt(testCase.EncryptedWif, testCase.Passphrase) + privKey, err := NEP2Decrypt(testCase.EncryptedWif, testCase.Passphrase) if testCase.Invalid { assert.Error(t, err) continue } assert.Nil(t, err) - - privKey, err := NewPrivateKeyFromWIF(privKeyString) - assert.Nil(t, err) - assert.Equal(t, testCase.PrivateKey, privKey.String()) wif := privKey.WIF() @@ -48,3 +43,39 @@ func TestNEP2Decrypt(t *testing.T) { assert.Equal(t, testCase.Address, address) } } + +func TestNEP2DecryptErrors(t *testing.T) { + p := "qwerty" + + // Not a base58-encoded value + s := "qazwsx" + _, err := NEP2Decrypt(s, p) + assert.Error(t, err) + + // Valid base58, but not a NEP-2 format. + s = "KxhEDBQyyEFymvfJD96q8stMbJMbZUb6D1PmXqBWZDU2WvbvVs9o" + _, err = NEP2Decrypt(s, p) + assert.Error(t, err) +} + +func TestValidateNEP2Format(t *testing.T) { + // Wrong length. + s := []byte("gobbledygook") + assert.Error(t, validateNEP2Format(s)) + + // Wrong header 1. + s = []byte("gobbledygookgobbledygookgobbledygookgob") + assert.Error(t, validateNEP2Format(s)) + + // Wrong header 2. + s[0] = 0x01 + assert.Error(t, validateNEP2Format(s)) + + // Wrong header 3. + s[1] = 0x42 + assert.Error(t, validateNEP2Format(s)) + + // OK + s[2] = 0xe0 + assert.NoError(t, validateNEP2Format(s)) +} diff --git a/pkg/wallet/account.go b/pkg/wallet/account.go index 02fb52b75..c1ddd1317 100644 --- a/pkg/wallet/account.go +++ b/pkg/wallet/account.go @@ -1,6 +1,8 @@ package wallet import ( + "errors" + "github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/util" ) @@ -60,14 +62,16 @@ func NewAccount() (*Account, error) { return newAccountFromPrivateKey(priv), nil } -// DecryptAccount decrypts the encryptedWIF with the given passphrase and -// return the decrypted Account. -func DecryptAccount(encryptedWIF, passphrase string) (*Account, error) { - wif, err := keys.NEP2Decrypt(encryptedWIF, passphrase) - if err != nil { - return nil, err +// Decrypt decrypts the EncryptedWIF with the given passphrase returning error +// if anything goes wrong. +func (a *Account) Decrypt(passphrase string) error { + var err error + + if a.EncryptedWIF == "" { + return errors.New("no encrypted wif in the account") } - return NewAccountFromWIF(wif) + a.privateKey, err = keys.NEP2Decrypt(a.EncryptedWIF, passphrase) + return err } // Encrypt encrypts the wallet's PrivateKey with the given passphrase diff --git a/pkg/wallet/account_test.go b/pkg/wallet/account_test.go index 343c993bc..da84055f2 100644 --- a/pkg/wallet/account_test.go +++ b/pkg/wallet/account_test.go @@ -6,32 +6,32 @@ import ( "github.com/CityOfZion/neo-go/pkg/internal/keytestcases" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAccount(t *testing.T) { - for _, testCase := range keytestcases.Arr { - acc, err := NewAccountFromWIF(testCase.Wif) - if testCase.Invalid { - assert.Error(t, err) - continue - } - - assert.NoError(t, err) - compareFields(t, testCase, acc) - } + acc, err := NewAccount() + require.NoError(t, err) + require.NotNil(t, acc) } func TestDecryptAccount(t *testing.T) { for _, testCase := range keytestcases.Arr { - acc, err := DecryptAccount(testCase.EncryptedWif, testCase.Passphrase) + acc := &Account{EncryptedWIF: testCase.EncryptedWif} + assert.Nil(t, acc.PrivateKey()) + err := acc.Decrypt(testCase.Passphrase) if testCase.Invalid { assert.Error(t, err) continue } assert.NoError(t, err) - compareFields(t, testCase, acc) + assert.NotNil(t, acc.PrivateKey()) + assert.Equal(t, testCase.PrivateKey, acc.privateKey.String()) } + // No encrypted key. + acc := &Account{} + require.Error(t, acc.Decrypt("qwerty")) } func TestNewFromWif(t *testing.T) {