diff --git a/pkg/crypto/keys/private_key.go b/pkg/crypto/keys/private_key.go index 744e49ab1..8facca0d9 100644 --- a/pkg/crypto/keys/private_key.go +++ b/pkg/crypto/keys/private_key.go @@ -110,7 +110,9 @@ func NewPrivateKeyFromWIF(wif string) (*PrivateKey, error) { // Good documentation about this process can be found here: // https://en.bitcoin.it/wiki/Wallet_import_format func (p *PrivateKey) WIF() string { - w, err := WIFEncode(p.Bytes(), WIFVersion, true) + pb := p.Bytes() + defer slice.Clean(pb) + w, err := WIFEncode(pb, WIFVersion, true) // The only way WIFEncode() can fail is if we're to give it a key of // wrong size, but we have a proper key here, aren't we? if err != nil { diff --git a/pkg/crypto/keys/wif.go b/pkg/crypto/keys/wif.go index 7da78ea8e..1ec908cdd 100644 --- a/pkg/crypto/keys/wif.go +++ b/pkg/crypto/keys/wif.go @@ -59,36 +59,31 @@ func WIFDecode(wif string, version byte) (*WIF, error) { if version == 0x00 { version = WIFVersion } + w := &WIF{ + Version: version, + S: wif, + } + switch len(b) { + case 33: // OK, uncompressed public key. + case 34: // OK, compressed public key. + // Check the compression flag. + if b[33] != 0x01 { + return nil, fmt.Errorf("invalid compression flag %d expecting %d", b[33], 0x01) + } + w.Compressed = true + default: + return nil, fmt.Errorf("invalid WIF length %d, expecting 33 or 34", len(b)) + } + if b[0] != version { return nil, fmt.Errorf("invalid WIF version got %d, expected %d", b[0], version) } // Derive the PrivateKey. - privKey, err := NewPrivateKeyFromBytes(b[1:33]) + w.PrivateKey, err = NewPrivateKeyFromBytes(b[1:33]) if err != nil { return nil, err } - w := &WIF{ - Version: version, - PrivateKey: privKey, - S: wif, - } - // This is an uncompressed WIF. - if len(b) == 33 { - w.Compressed = false - return w, nil - } - - if len(b) != 34 { - return nil, fmt.Errorf("invalid WIF length: %d expecting 34", len(b)) - } - - // Check the compression flag. - if b[33] != 0x01 { - return nil, fmt.Errorf("invalid compression flag %d expecting %d", b[34], 0x01) - } - - w.Compressed = true return w, nil } diff --git a/pkg/crypto/keys/wif_test.go b/pkg/crypto/keys/wif_test.go index 6fae58167..f29269e72 100644 --- a/pkg/crypto/keys/wif_test.go +++ b/pkg/crypto/keys/wif_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "testing" + "github.com/nspcc-dev/neo-go/pkg/encoding/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -65,3 +66,35 @@ func TestWIFEncodeDecode(t *testing.T) { _, err := WIFEncode(wifInv, 0, true) require.Error(t, err) } + +func TestBadWIFDecode(t *testing.T) { + _, err := WIFDecode("garbage", 0) + require.Error(t, err) + + s := base58.CheckEncode([]byte{}) + _, err = WIFDecode(s, 0) + require.Error(t, err) + + uncompr := make([]byte, 33) + compr := make([]byte, 34) + + s = base58.CheckEncode(compr) + _, err = WIFDecode(s, 0) + require.Error(t, err) + + s = base58.CheckEncode(uncompr) + _, err = WIFDecode(s, 0) + require.Error(t, err) + + compr[33] = 1 + compr[0] = WIFVersion + uncompr[0] = WIFVersion + + s = base58.CheckEncode(compr) + _, err = WIFDecode(s, 0) + require.NoError(t, err) + + s = base58.CheckEncode(uncompr) + _, err = WIFDecode(s, 0) + require.NoError(t, err) +}