diff --git a/dsa.go b/dsa.go index 8f5b95d..f15b3a7 100644 --- a/dsa.go +++ b/dsa.go @@ -2,6 +2,7 @@ package rfc6979 import ( "crypto/dsa" + "hash" "math/big" ) @@ -12,7 +13,7 @@ import ( // Note that FIPS 186-3 section 4.6 specifies that the hash should be truncated // to the byte-length of the subgroup. This function does not perform that // truncation itself. -func SignDSA(priv *dsa.PrivateKey, hash []byte, alg HashFunc) (r, s *big.Int, err error) { +func SignDSA(priv *dsa.PrivateKey, hash []byte, alg func() hash.Hash) (r, s *big.Int, err error) { n := priv.Q.BitLen() if n&7 != 0 { err = dsa.ErrInvalidPublicKey diff --git a/dsa_test.go b/dsa_test.go index 4a64918..edca20b 100644 --- a/dsa_test.go +++ b/dsa_test.go @@ -1,4 +1,4 @@ -package rfc6979 +package rfc6979_test import ( "crypto/dsa" @@ -6,14 +6,17 @@ import ( "crypto/sha256" "crypto/sha512" "encoding/hex" + "hash" "math/big" "testing" + + "github.com/codahale/rfc6979" ) type dsaFixture struct { name string key *dsaKey - alg HashFunc + alg func() hash.Hash message string r, s string } @@ -239,7 +242,7 @@ func testDsaFixture(f *dsaFixture, t *testing.T) { digest = digest[0:g] } - r, s, err := SignDSA(f.key.key, digest, f.alg) + r, s, err := rfc6979.SignDSA(f.key.key, digest, f.alg) if err != nil { t.Error(err) return diff --git a/ecdsa.go b/ecdsa.go index 55e9b63..cfb3b7a 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -3,25 +3,10 @@ package rfc6979 import ( "crypto/ecdsa" "crypto/elliptic" + "hash" "math/big" ) -// copied from crypto/ecdsa -func hashToInt(hash []byte, c elliptic.Curve) *big.Int { - orderBits := c.Params().N.BitLen() - orderBytes := (orderBits + 7) / 8 - if len(hash) > orderBytes { - hash = hash[:orderBytes] - } - - ret := new(big.Int).SetBytes(hash) - excess := len(hash)*8 - orderBits - if excess > 0 { - ret.Rsh(ret, uint(excess)) - } - return ret -} - // SignECDSA signs an arbitrary length hash (which should be the result of // hashing a larger message) using the private key, priv. It returns the // signature as a pair of integers. @@ -29,7 +14,7 @@ func hashToInt(hash []byte, c elliptic.Curve) *big.Int { // Note that FIPS 186-3 section 4.6 specifies that the hash should be truncated // to the byte-length of the subgroup. This function does not perform that // truncation itself. -func SignECDSA(priv *ecdsa.PrivateKey, hash []byte, alg HashFunc) (r, s *big.Int, err error) { +func SignECDSA(priv *ecdsa.PrivateKey, hash []byte, alg func() hash.Hash) (r, s *big.Int, err error) { c := priv.PublicKey.Curve N := c.Params().N @@ -53,3 +38,19 @@ func SignECDSA(priv *ecdsa.PrivateKey, hash []byte, alg HashFunc) (r, s *big.Int return } + +// copied from crypto/ecdsa +func hashToInt(hash []byte, c elliptic.Curve) *big.Int { + orderBits := c.Params().N.BitLen() + orderBytes := (orderBits + 7) / 8 + if len(hash) > orderBytes { + hash = hash[:orderBytes] + } + + ret := new(big.Int).SetBytes(hash) + excess := len(hash)*8 - orderBits + if excess > 0 { + ret.Rsh(ret, uint(excess)) + } + return ret +} diff --git a/ecdsa_test.go b/ecdsa_test.go index 7955b81..d56ca1f 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -1,4 +1,4 @@ -package rfc6979 +package rfc6979_test import ( "crypto/ecdsa" @@ -6,14 +6,17 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "hash" "math/big" "testing" + + "github.com/codahale/rfc6979" ) type ecdsaFixture struct { name string key *ecdsaKey - alg HashFunc + alg func() hash.Hash message string r, s string } @@ -425,7 +428,7 @@ func testEcsaFixture(f *ecdsaFixture, t *testing.T) { digest = digest[0:g] } - r, s, err := SignECDSA(f.key.key, digest, f.alg) + r, s, err := rfc6979.SignECDSA(f.key.key, digest, f.alg) if err != nil { t.Error(err) return diff --git a/rfc6979.go b/rfc6979.go index da27a4b..cb7b4b0 100644 --- a/rfc6979.go +++ b/rfc6979.go @@ -22,11 +22,8 @@ import ( "math/big" ) -// HashFunc is a function which provides a fresh Hash (e.g., sha256.New). -type HashFunc func() hash.Hash - // mac returns an HMAC of the given key and message. -func (alg HashFunc) mac(k, m, buf []byte) []byte { +func mac(alg func() hash.Hash, k, m, buf []byte) []byte { h := hmac.New(alg, k) h.Write(m) return h.Sum(buf[:0]) @@ -76,7 +73,7 @@ func bits2octets(in []byte, q *big.Int, qlen, rolen int) []byte { var one = big.NewInt(1) // https://tools.ietf.org/html/rfc6979#section-3.2 -func generateSecret(q, x *big.Int, alg HashFunc, hash []byte, test func(*big.Int) bool) { +func generateSecret(q, x *big.Int, alg func() hash.Hash, hash []byte, test func(*big.Int) bool) { qlen := q.BitLen() holen := alg().Size() rolen := (qlen + 7) >> 3 @@ -89,25 +86,25 @@ func generateSecret(q, x *big.Int, alg HashFunc, hash []byte, test func(*big.Int k := bytes.Repeat([]byte{0x00}, holen) // Step D - k = alg.mac(k, append(append(v, 0x00), bx...), k) + k = mac(alg, k, append(append(v, 0x00), bx...), k) // Step E - v = alg.mac(k, v, v) + v = mac(alg, k, v, v) // Step F - k = alg.mac(k, append(append(v, 0x01), bx...), k) + k = mac(alg, k, append(append(v, 0x01), bx...), k) // Step G - v = alg.mac(k, v, v) + v = mac(alg, k, v, v) // Step H for { // Step H1 - t := make([]byte, 0) + var t []byte // Step H2 for len(t) < qlen/8 { - v = alg.mac(k, v, v) + v = mac(alg, k, v, v) t = append(t, v...) } @@ -116,7 +113,7 @@ func generateSecret(q, x *big.Int, alg HashFunc, hash []byte, test func(*big.Int if secret.Cmp(one) >= 0 && secret.Cmp(q) < 0 && test(secret) { return } - k = alg.mac(k, append(v, 0x00), k) - v = alg.mac(k, v, v) + k = mac(alg, k, append(v, 0x00), k) + v = mac(alg, k, v, v) } }