crypto: drop home-grown elliptic crypto, use crypto/elliptic

As NEO uses P256 we can use standard crypto/elliptic library for almost
everything, the only exception being decompression of the Y coordinate. For
some reason the standard library only supports uncompressed format in its
Marshal()/Unmarshal() functions. elliptic.P256() is known to have
constant-time implementation, so it fixes #245 (and the decompression using
big.Int operates on public key, so nobody really cares about that part being
constant-time).

New decompress function is inspired by
https://stackoverflow.com/questions/46283760, even though the previous one
really did the same thing just in a little less obvious way.
This commit is contained in:
Roman Khimov 2019-09-05 00:12:39 +03:00
parent 0b884b92b3
commit f0fbe9f6c9
5 changed files with 66 additions and 348 deletions

View file

@ -1,256 +0,0 @@
package crypto
// Original work completed by @vsergeev: https://github.com/vsergeev/btckeygenie
// Expanded and tweaked upon here under MIT license.
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
"github.com/CityOfZion/neo-go/pkg/util"
)
type (
// EllipticCurve represents the parameters of a short Weierstrass equation elliptic
// curve.
EllipticCurve struct {
A *big.Int
B *big.Int
P *big.Int
G ECPoint
N *big.Int
H *big.Int
}
// ECPoint represents a point on the EllipticCurve.
ECPoint struct {
X *big.Int
Y *big.Int
}
)
// NewEllipticCurve returns a ready to use EllipticCurve with preconfigured
// fields for the NEO protocol.
func NewEllipticCurve() EllipticCurve {
c := EllipticCurve{}
c.P, _ = new(big.Int).SetString(
"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF", 16,
)
c.A, _ = new(big.Int).SetString(
"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC", 16,
)
c.B, _ = new(big.Int).SetString(
"5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B", 16,
)
c.G.X, _ = new(big.Int).SetString(
"6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296", 16,
)
c.G.Y, _ = new(big.Int).SetString(
"4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5", 16,
)
c.N, _ = new(big.Int).SetString(
"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 16,
)
c.H, _ = new(big.Int).SetString("01", 16)
return c
}
// ECPointFromReader return a new point from the given reader.
// f == 4, 6 or 7 are not implemented.
func ECPointFromReader(r io.Reader) (point ECPoint, err error) {
var f uint8
if err = binary.Read(r, binary.LittleEndian, &f); err != nil {
return
}
// Infinity
if f == 0 {
return ECPoint{
X: new(big.Int),
Y: new(big.Int),
}, nil
}
if f == 2 || f == 3 {
y := new(big.Int).SetBytes([]byte{f & 1})
data := make([]byte, 32)
if err = binary.Read(r, binary.LittleEndian, data); err != nil {
return
}
data = util.ArrayReverse(data)
data = append(data, byte(0x00))
return ECPoint{
X: new(big.Int).SetBytes(data),
Y: y,
}, nil
}
return
}
// EncodeBinary encodes the point to the given io.Writer.
func (p ECPoint) EncodeBinary(w io.Writer) error {
bx := p.X.Bytes()
padded := append(
bytes.Repeat(
[]byte{0x00},
32-len(bx),
),
bx...,
)
prefix := byte(0x03)
if p.Y.Bit(0) == 0 {
prefix = byte(0x02)
}
buf := make([]byte, len(padded)+1)
buf[0] = prefix
copy(buf[1:], padded)
return binary.Write(w, binary.LittleEndian, buf)
}
// String implements the Stringer interface.
func (p *ECPoint) String() string {
if p.IsInfinity() {
return "00"
}
bx := hex.EncodeToString(p.X.Bytes())
by := hex.EncodeToString(p.Y.Bytes())
return fmt.Sprintf("%s%s", bx, by)
}
// IsInfinity checks if point P is infinity on EllipticCurve ec.
func (p *ECPoint) IsInfinity() bool {
return p.X == nil && p.Y == nil
}
// IsInfinity checks if point P is infinity on EllipticCurve ec.
func (c *EllipticCurve) IsInfinity(P ECPoint) bool {
return P.X == nil && P.Y == nil
}
// IsOnCurve checks if point P is on EllipticCurve ec.
func (c *EllipticCurve) IsOnCurve(P ECPoint) bool {
if c.IsInfinity(P) {
return false
}
lhs := mulMod(P.Y, P.Y, c.P)
rhs := addMod(
addMod(
expMod(P.X, big.NewInt(3), c.P),
mulMod(c.A, P.X, c.P), c.P),
c.B, c.P)
return lhs.Cmp(rhs) == 0
}
// Add computes R = P + Q on EllipticCurve ec.
func (c *EllipticCurve) Add(P, Q ECPoint) (R ECPoint) {
// See rules 1-5 on SEC1 pg.7 http://www.secg.org/collateral/sec1_final.pdf
if c.IsInfinity(P) && c.IsInfinity(Q) {
R.X = nil
R.Y = nil
} else if c.IsInfinity(P) {
R.X = new(big.Int).Set(Q.X)
R.Y = new(big.Int).Set(Q.Y)
} else if c.IsInfinity(Q) {
R.X = new(big.Int).Set(P.X)
R.Y = new(big.Int).Set(P.Y)
} else if P.X.Cmp(Q.X) == 0 && addMod(P.Y, Q.Y, c.P).Sign() == 0 {
R.X = nil
R.Y = nil
} else if P.X.Cmp(Q.X) == 0 && P.Y.Cmp(Q.Y) == 0 && P.Y.Sign() != 0 {
num := addMod(
mulMod(big.NewInt(3),
mulMod(P.X, P.X, c.P), c.P),
c.A, c.P)
den := invMod(mulMod(big.NewInt(2), P.Y, c.P), c.P)
lambda := mulMod(num, den, c.P)
R.X = subMod(
mulMod(lambda, lambda, c.P),
mulMod(big.NewInt(2), P.X, c.P),
c.P)
R.Y = subMod(
mulMod(lambda, subMod(P.X, R.X, c.P), c.P),
P.Y, c.P)
} else if P.X.Cmp(Q.X) != 0 {
num := subMod(Q.Y, P.Y, c.P)
den := invMod(subMod(Q.X, P.X, c.P), c.P)
lambda := mulMod(num, den, c.P)
R.X = subMod(
subMod(
mulMod(lambda, lambda, c.P),
P.X, c.P),
Q.X, c.P)
R.Y = subMod(
mulMod(lambda,
subMod(P.X, R.X, c.P), c.P),
P.Y, c.P)
} else {
panic(fmt.Sprintf("Unsupported point addition: %v + %v", P, Q))
}
return R
}
// ScalarMult computes Q = k * P on EllipticCurve ec.
func (c *EllipticCurve) ScalarMult(k *big.Int, P ECPoint) (Q ECPoint) {
// Implementation based on pseudocode here:
// https://en.wikipedia.org/wiki/Elliptic_curve_point_multiplication#Montgomery_ladder
var R0 ECPoint
var R1 ECPoint
R0.X = nil
R0.Y = nil
R1.X = new(big.Int).Set(P.X)
R1.Y = new(big.Int).Set(P.Y)
for i := c.N.BitLen() - 1; i >= 0; i-- {
if k.Bit(i) == 0 {
R1 = c.Add(R0, R1)
R0 = c.Add(R0, R0)
} else {
R0 = c.Add(R0, R1)
R1 = c.Add(R1, R1)
}
}
return R0
}
// ScalarBaseMult computes Q = k * G on EllipticCurve ec.
func (c *EllipticCurve) ScalarBaseMult(k *big.Int) (Q ECPoint) {
return c.ScalarMult(k, c.G)
}
// Decompress decompresses coordinate x and ylsb (y's least significant bit) into a ECPoint P on EllipticCurve ec.
func (c *EllipticCurve) Decompress(x *big.Int, ylsb uint) (P ECPoint, err error) {
/* y**2 = x**3 + a*x + b % p */
rhs := addMod(
addMod(
expMod(x, big.NewInt(3), c.P),
mulMod(c.A, x, c.P),
c.P),
c.B, c.P)
y := sqrtMod(rhs, c.P)
if y.Bit(0) != (ylsb & 0x1) {
y = subMod(big.NewInt(0), y, c.P)
}
P.X = x
P.Y = y
if !c.IsOnCurve(P) {
return P, errors.New("compressed (x, ylsb) not on curve")
}
return P, nil
}

View file

@ -10,10 +10,8 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io"
"math/big" "math/big"
"github.com/CityOfZion/neo-go/pkg/crypto"
"github.com/nspcc-dev/rfc6979" "github.com/nspcc-dev/rfc6979"
) )
@ -24,18 +22,11 @@ type PrivateKey struct {
// NewPrivateKey creates a new random private key. // NewPrivateKey creates a new random private key.
func NewPrivateKey() (*PrivateKey, error) { func NewPrivateKey() (*PrivateKey, error) {
c := crypto.NewEllipticCurve() priv, _, _, err := elliptic.GenerateKey(elliptic.P256(), rand.Reader)
b := make([]byte, c.N.BitLen()/8+8) if err != nil {
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return nil, err return nil, err
} }
return &PrivateKey{b: priv}, nil
d := new(big.Int).SetBytes(b)
d.Mod(d, new(big.Int).Sub(c.N, big.NewInt(1)))
d.Add(d, big.NewInt(1))
p := &PrivateKey{b: d.Bytes()}
return p, nil
} }
// NewPrivateKeyFromHex returns a PrivateKey created from the // NewPrivateKeyFromHex returns a PrivateKey created from the
@ -72,16 +63,16 @@ func (p *PrivateKey) PublicKey() (*PublicKey, error) {
var ( var (
err error err error
pk PublicKey pk PublicKey
c = crypto.NewEllipticCurve() c = elliptic.P256()
q = new(big.Int).SetBytes(p.b) q = new(big.Int).SetBytes(p.b)
) )
point := c.ScalarBaseMult(q) x, y := c.ScalarBaseMult(q.Bytes())
if !c.IsOnCurve(point) { if !c.IsOnCurve(x, y) {
return nil, errors.New("failed to derive public key using elliptic curve") return nil, errors.New("failed to derive public key using elliptic curve")
} }
bx := point.X.Bytes() bx := x.Bytes()
padded := append( padded := append(
bytes.Repeat( bytes.Repeat(
[]byte{0x00}, []byte{0x00},
@ -91,7 +82,7 @@ func (p *PrivateKey) PublicKey() (*PublicKey, error) {
) )
prefix := []byte{0x03} prefix := []byte{0x03}
if point.Y.Bit(0) == 0 { if y.Bit(0) == 0 {
prefix = []byte{0x02} prefix = []byte{0x02}
} }
b := append(prefix, padded...) b := append(prefix, padded...)

View file

@ -7,6 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt"
"io" "io"
"math/big" "math/big"
@ -35,9 +36,10 @@ func (keys PublicKeys) Less(i, j int) bool {
} }
// PublicKey represents a public key and provides a high level // PublicKey represents a public key and provides a high level
// API around the ECPoint. // API around the X/Y point.
type PublicKey struct { type PublicKey struct {
crypto.ECPoint X *big.Int
Y *big.Int
} }
// NewPublicKeyFromString return a public key created from the // NewPublicKeyFromString return a public key created from the
@ -58,7 +60,7 @@ func NewPublicKeyFromString(s string) (*PublicKey, error) {
// Bytes returns the byte array representation of the public key. // Bytes returns the byte array representation of the public key.
func (p *PublicKey) Bytes() []byte { func (p *PublicKey) Bytes() []byte {
if p.IsInfinity() { if p.isInfinity() {
return []byte{0x00} return []byte{0x00}
} }
@ -89,14 +91,38 @@ func NewPublicKeyFromRawBytes(data []byte) (*PublicKey, error) {
return nil, errors.New("given bytes aren't ECDSA public key") return nil, errors.New("given bytes aren't ECDSA public key")
} }
key := PublicKey{ key := PublicKey{
crypto.ECPoint{ X: pk.X,
X: pk.X, Y: pk.Y,
Y: pk.Y,
},
} }
return &key, nil return &key, nil
} }
// decodeCompressedY performs decompression of Y coordinate for given X and Y's least significant bit
func decodeCompressedY(x *big.Int, ylsb uint) (*big.Int, error) {
c := elliptic.P256()
cp := c.Params()
three := big.NewInt(3)
/* y**2 = x**3 + a*x + b % p */
xCubed := new(big.Int).Exp(x, three, cp.P)
threeX := new(big.Int).Mul(x, three)
threeX.Mod(threeX, cp.P)
ySquared := new(big.Int).Sub(xCubed, threeX)
ySquared.Add(ySquared, cp.B)
ySquared.Mod(ySquared, cp.P)
y := new(big.Int).ModSqrt(ySquared, cp.P)
if y == nil {
return nil, errors.New("error computing Y for compressed point")
}
if y.Bit(0) != ylsb {
y.Neg(y)
y.Mod(y, cp.P)
}
if !c.IsOnCurve(x, y) {
return nil, errors.New("compressed (x, ylsb) not on curve")
}
return y, nil
}
// DecodeBytes decodes a PublicKey from the given slice of bytes. // DecodeBytes decodes a PublicKey from the given slice of bytes.
func (p *PublicKey) DecodeBytes(data []byte) error { func (p *PublicKey) DecodeBytes(data []byte) error {
l := len(data) l := len(data)
@ -104,19 +130,22 @@ func (p *PublicKey) DecodeBytes(data []byte) error {
switch prefix := data[0]; prefix { switch prefix := data[0]; prefix {
// Infinity // Infinity
case 0x00: case 0x00:
p.ECPoint = crypto.ECPoint{} p.X = nil
p.Y = nil
// Compressed public keys // Compressed public keys
case 0x02, 0x03: case 0x02, 0x03:
if l < 33 { if l < 33 {
return errors.Errorf("bad binary size(%d)", l) return errors.Errorf("bad binary size(%d)", l)
} }
c := crypto.NewEllipticCurve() x := new(big.Int).SetBytes(data[1:])
var err error ylsb := uint(prefix&0x1)
p.ECPoint, err = c.Decompress(new(big.Int).SetBytes(data[1:]), uint(prefix&0x1)) y, err := decodeCompressedY(x, ylsb)
if err != nil { if err != nil {
return err return err
} }
p.X = x
p.Y = y
case 0x04: case 0x04:
if l < 66 { if l < 66 {
return errors.Errorf("bad binary size(%d)", l) return errors.Errorf("bad binary size(%d)", l)
@ -141,7 +170,8 @@ func (p *PublicKey) DecodeBinary(r io.Reader) error {
// Infinity // Infinity
switch prefix { switch prefix {
case 0x00: case 0x00:
p.ECPoint = crypto.ECPoint{} p.X = nil
p.Y = nil
return nil return nil
// Compressed public keys // Compressed public keys
case 0x02, 0x03: case 0x02, 0x03:
@ -206,3 +236,18 @@ func (p *PublicKey) Verify(signature []byte, hash []byte) bool {
sBytes := new(big.Int).SetBytes(signature[32:64]) sBytes := new(big.Int).SetBytes(signature[32:64])
return ecdsa.Verify(publicKey, hash, rBytes, sBytes) return ecdsa.Verify(publicKey, hash, rBytes, sBytes)
} }
// isInfinity checks if point P is infinity on EllipticCurve ec.
func (p *PublicKey) isInfinity() bool {
return p.X == nil && p.Y == nil
}
// String implements the Stringer interface.
func (p *PublicKey) String() string {
if p.isInfinity() {
return "00"
}
bx := hex.EncodeToString(p.X.Bytes())
by := hex.EncodeToString(p.Y.Bytes())
return fmt.Sprintf("%s%s", bx, by)
}

View file

@ -5,12 +5,11 @@ import (
"encoding/hex" "encoding/hex"
"testing" "testing"
"github.com/CityOfZion/neo-go/pkg/crypto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestEncodeDecodeInfinity(t *testing.T) { func TestEncodeDecodeInfinity(t *testing.T) {
key := &PublicKey{crypto.ECPoint{}} key := &PublicKey{}
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
assert.Nil(t, key.EncodeBinary(buf)) assert.Nil(t, key.EncodeBinary(buf))
assert.Equal(t, 1, buf.Len()) assert.Equal(t, 1, buf.Len())

View file

@ -1,61 +0,0 @@
package crypto
import "math/big"
// addMod computes z = (x + y) % p.
func addMod(x *big.Int, y *big.Int, p *big.Int) (z *big.Int) {
z = new(big.Int).Add(x, y)
z.Mod(z, p)
return z
}
// subMod computes z = (x - y) % p.
func subMod(x *big.Int, y *big.Int, p *big.Int) (z *big.Int) {
z = new(big.Int).Sub(x, y)
z.Mod(z, p)
return z
}
// mulMod computes z = (x * y) % p.
func mulMod(x *big.Int, y *big.Int, p *big.Int) (z *big.Int) {
n := new(big.Int).Set(x)
z = big.NewInt(0)
for i := 0; i < y.BitLen(); i++ {
if y.Bit(i) == 1 {
z = addMod(z, n, p)
}
n = addMod(n, n, p)
}
return z
}
// invMod computes z = (1/x) % p.
func invMod(x *big.Int, p *big.Int) (z *big.Int) {
z = new(big.Int).ModInverse(x, p)
return z
}
// expMod computes z = (x^e) % p.
func expMod(x *big.Int, y *big.Int, p *big.Int) (z *big.Int) {
z = new(big.Int).Exp(x, y, p)
return z
}
// sqrtMod computes z = sqrt(x) % p.
func sqrtMod(x *big.Int, p *big.Int) (z *big.Int) {
/* assert that p % 4 == 3 */
if new(big.Int).Mod(p, big.NewInt(4)).Cmp(big.NewInt(3)) != 0 {
panic("p is not equal to 3 mod 4!")
}
/* z = sqrt(x) % p = x^((p+1)/4) % p */
/* e = (p+1)/4 */
e := new(big.Int).Add(p, big.NewInt(1))
e = e.Rsh(e, 2)
z = expMod(x, e, p)
return z
}