crypto/keys: enforce length in PublicKey.DecodeBytes()

Signed-off-by: Evgeniy Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgeniy Stratonikov 2021-08-13 10:37:24 +03:00
parent 5aff82aef4
commit bb137abb03
2 changed files with 14 additions and 2 deletions

View file

@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
gio "io"
"math/big" "math/big"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
@ -234,9 +235,17 @@ func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, e
func (p *PublicKey) DecodeBytes(data []byte) error { func (p *PublicKey) DecodeBytes(data []byte) error {
b := io.NewBinReaderFromBuf(data) b := io.NewBinReaderFromBuf(data)
p.DecodeBinary(b) p.DecodeBinary(b)
if b.Err != nil {
return b.Err return b.Err
} }
b.ReadB()
if b.Err != gio.EOF {
return errors.New("extra data")
}
return nil
}
// DecodeBinary decodes a PublicKey from the given BinReader using information // DecodeBinary decodes a PublicKey from the given BinReader using information
// about the EC curve to decompress Y point. Secp256r1 is a default value for EC curve. // about the EC curve to decompress Y point. Secp256r1 is a default value for EC curve.
func (p *PublicKey) DecodeBinary(r *io.BinReader) { func (p *PublicKey) DecodeBinary(r *io.BinReader) {
@ -375,7 +384,7 @@ func (p *PublicKey) UnmarshalJSON(data []byte) error {
return errors.New("wrong format") return errors.New("wrong format")
} }
bytes := make([]byte, l-2) bytes := make([]byte, hex.DecodedLen(l-2))
_, err := hex.Decode(bytes, data[1:l-1]) _, err := hex.Decode(bytes, data[1:l-1])
if err != nil { if err != nil {
return err return err

View file

@ -67,6 +67,9 @@ func TestNewPublicKeyFromBytes(t *testing.T) {
pub2, err := NewPublicKeyFromBytes(b, elliptic.P256()) pub2, err := NewPublicKeyFromBytes(b, elliptic.P256())
require.NoError(t, err) require.NoError(t, err)
require.Same(t, pub, pub2) require.Same(t, pub, pub2)
_, err = NewPublicKeyFromBytes([]byte{0x00, 0x01}, elliptic.P256())
require.Error(t, err)
} }
func TestDecodeFromString(t *testing.T) { func TestDecodeFromString(t *testing.T) {