keys: check length first, then do things in WIFDecode

Otherwise we can easily panic there on bad input.
This commit is contained in:
Roman Khimov 2022-09-01 22:23:26 +03:00
parent 3c722a9498
commit eb67145f81
3 changed files with 53 additions and 23 deletions

View file

@ -110,7 +110,9 @@ func NewPrivateKeyFromWIF(wif string) (*PrivateKey, error) {
// Good documentation about this process can be found here: // Good documentation about this process can be found here:
// https://en.bitcoin.it/wiki/Wallet_import_format // https://en.bitcoin.it/wiki/Wallet_import_format
func (p *PrivateKey) WIF() string { func (p *PrivateKey) WIF() string {
w, err := WIFEncode(p.Bytes(), WIFVersion, true) pb := p.Bytes()
defer slice.Clean(pb)
w, err := WIFEncode(pb, WIFVersion, true)
// The only way WIFEncode() can fail is if we're to give it a key of // The only way WIFEncode() can fail is if we're to give it a key of
// wrong size, but we have a proper key here, aren't we? // wrong size, but we have a proper key here, aren't we?
if err != nil { if err != nil {

View file

@ -59,36 +59,31 @@ func WIFDecode(wif string, version byte) (*WIF, error) {
if version == 0x00 { if version == 0x00 {
version = WIFVersion version = WIFVersion
} }
w := &WIF{
Version: version,
S: wif,
}
switch len(b) {
case 33: // OK, uncompressed public key.
case 34: // OK, compressed public key.
// Check the compression flag.
if b[33] != 0x01 {
return nil, fmt.Errorf("invalid compression flag %d expecting %d", b[33], 0x01)
}
w.Compressed = true
default:
return nil, fmt.Errorf("invalid WIF length %d, expecting 33 or 34", len(b))
}
if b[0] != version { if b[0] != version {
return nil, fmt.Errorf("invalid WIF version got %d, expected %d", b[0], version) return nil, fmt.Errorf("invalid WIF version got %d, expected %d", b[0], version)
} }
// Derive the PrivateKey. // Derive the PrivateKey.
privKey, err := NewPrivateKeyFromBytes(b[1:33]) w.PrivateKey, err = NewPrivateKeyFromBytes(b[1:33])
if err != nil { if err != nil {
return nil, err return nil, err
} }
w := &WIF{
Version: version,
PrivateKey: privKey,
S: wif,
}
// This is an uncompressed WIF.
if len(b) == 33 {
w.Compressed = false
return w, nil
}
if len(b) != 34 {
return nil, fmt.Errorf("invalid WIF length: %d expecting 34", len(b))
}
// Check the compression flag.
if b[33] != 0x01 {
return nil, fmt.Errorf("invalid compression flag %d expecting %d", b[34], 0x01)
}
w.Compressed = true
return w, nil return w, nil
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/hex" "encoding/hex"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/encoding/base58"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -65,3 +66,35 @@ func TestWIFEncodeDecode(t *testing.T) {
_, err := WIFEncode(wifInv, 0, true) _, err := WIFEncode(wifInv, 0, true)
require.Error(t, err) require.Error(t, err)
} }
func TestBadWIFDecode(t *testing.T) {
_, err := WIFDecode("garbage", 0)
require.Error(t, err)
s := base58.CheckEncode([]byte{})
_, err = WIFDecode(s, 0)
require.Error(t, err)
uncompr := make([]byte, 33)
compr := make([]byte, 34)
s = base58.CheckEncode(compr)
_, err = WIFDecode(s, 0)
require.Error(t, err)
s = base58.CheckEncode(uncompr)
_, err = WIFDecode(s, 0)
require.Error(t, err)
compr[33] = 1
compr[0] = WIFVersion
uncompr[0] = WIFVersion
s = base58.CheckEncode(compr)
_, err = WIFDecode(s, 0)
require.NoError(t, err)
s = base58.CheckEncode(uncompr)
_, err = WIFDecode(s, 0)
require.NoError(t, err)
}