From 9485f49f3b512f4728e17354843c99bb1d8cfc78 Mon Sep 17 00:00:00 2001 From: Evgenii Date: Fri, 21 Jun 2019 22:29:08 +0300 Subject: [PATCH] Get rid of unsafe usage and add tests --- tz/hash.go | 58 ++++++------------------------------------------- tz/hash_avx2.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++ tz/hash_test.go | 48 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 107 insertions(+), 54 deletions(-) create mode 100644 tz/hash_avx2.go diff --git a/tz/hash.go b/tz/hash.go index dd12474..4d9b65f 100644 --- a/tz/hash.go +++ b/tz/hash.go @@ -7,7 +7,6 @@ import ( "errors" "hash" "math" - "unsafe" "github.com/nspcc-dev/tzhash/gf127" ) @@ -17,19 +16,12 @@ const ( hashBlockSize = 128 ) -type ( - digest struct { - x [4]gf127.GF127 - } +type digest struct { + x [4]gf127.GF127 +} - digest2 digest -) - -// type assertions -var ( - _ hash.Hash = new(digest) - _ hash.Hash = new(digest2) -) +// type assertion +var _ hash.Hash = new(digest) var ( minmax = [2]gf127.GF127{{0, 0}, {math.MaxUint64, math.MaxUint64}} @@ -93,46 +85,11 @@ func (d *digest) BlockSize() int { return hashBlockSize } -func (d *digest2) Write(data []byte) (n int, err error) { - n = len(data) - - // We need to transpose matrix, because - // mulBitRightx2 accepts matrix by columns, not rows - a := d.x[1] - d.x[1] = d.x[2] - d.x[2] = a - - h1 := (*gf127.GF127x2)(unsafe.Pointer(&d.x[0])) - h2 := (*gf127.GF127x2)(unsafe.Pointer(&d.x[2])) - for _, b := range data { - mulBitRightx2(h1, h2, &minmax[(b>>7)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>6)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>5)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>4)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>3)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>2)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>1)&1]) - mulBitRightx2(h1, h2, &minmax[(b>>0)&1]) - } - - // transpose matrix back - a = d.x[1] - d.x[1] = d.x[2] - d.x[2] = a - - return -} -func (d *digest2) Sum(b []byte) []byte { return (*digest)(d).Sum(b) } -func (d *digest2) Reset() { (*digest)(d).Reset() } -func (d *digest2) Size() int { return (*digest)(d).Size() } -func (d *digest2) BlockSize() int { return (*digest)(d).BlockSize() } -func (d *digest2) checkSum() [hashSize]byte { return (*digest)(d).checkSum() } - // Sum returnz Tillich-ZĂ©mor checksum of data func Sum(data []byte) [hashSize]byte { - d := new(digest2) + d := new(digest) d.Reset() - d.Write(data) + _, _ = d.Write(data) // no errors return d.checkSum() } @@ -215,4 +172,3 @@ func SubtractL(c, a []byte) (b []byte, err error) { } func mulBitRight(c00, c01, c10, c11, e *gf127.GF127) -func mulBitRightx2(c00c01 *gf127.GF127x2, c10c11 *gf127.GF127x2, e *gf127.GF127) diff --git a/tz/hash_avx2.go b/tz/hash_avx2.go new file mode 100644 index 0000000..3acfb56 --- /dev/null +++ b/tz/hash_avx2.go @@ -0,0 +1,55 @@ +package tz + +import ( + "hash" + + "github.com/nspcc-dev/tzhash/gf127" +) + +type digest2 struct { + x [2]gf127.GF127x2 +} + +var _ hash.Hash = new(digest2) + +func (d *digest2) Write(data []byte) (n int, err error) { + n = len(data) + for _, b := range data { + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>7)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>6)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>5)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>4)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>3)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>2)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>1)&1]) + mulBitRightx2(&d.x[0], &d.x[1], &minmax[(b>>0)&1]) + } + return +} + +func (d *digest2) Sum(in []byte) []byte { + // Make a copy of d so that caller can keep writing and summing. + d0 := *d + h := d0.checkSum() + return append(in, h[:]...) +} +func (d *digest2) Reset() { + d.x[0] = gf127.GF127x2{1, 0, 0, 0} + d.x[1] = gf127.GF127x2{0, 0, 0, 1} +} +func (d *digest2) Size() int { return hashSize } +func (d *digest2) BlockSize() int { return hashBlockSize } +func (d *digest2) checkSum() (b [hashSize]byte) { + // Matrix is stored transposed, + // but we need to use order consistent with digest. + h := d.x[0].ByteArray() + copy(b[:], h[:8]) + copy(b[16:], h[8:]) + + h = d.x[1].ByteArray() + copy(b[8:], h[:8]) + copy(b[24:], h[8:]) + return +} + +func mulBitRightx2(c00c10 *gf127.GF127x2, c01c11 *gf127.GF127x2, e *gf127.GF127) diff --git a/tz/hash_test.go b/tz/hash_test.go index a9d5105..08ec0c0 100644 --- a/tz/hash_test.go +++ b/tz/hash_test.go @@ -8,7 +8,49 @@ import ( "github.com/stretchr/testify/require" ) +var testCases = []struct { + input []byte + hash string +}{ + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8}, + "00000000000001e4a545e5b90fb6882b00000000000000c849cd88f79307f67100000000000000cd0c898cb68356e624000000000000007cbcdc7c5e89b16e4b", + }, + { + []byte{4, 8, 15, 16, 23, 42, 255, 0, 127, 65, 32, 123, 42, 45, 201, 210, 213, 244}, + "4db8a8e253903c70ab0efb65fe6de05a36d1dc9f567a147152d0148a86817b2062908d9b026a506007c1118e86901b672a39317c55ee3c10ac8efafa79efe8ee", + }, +} + func TestHash(t *testing.T) { + t.Run("test AVX digest", func(t *testing.T) { + d := new(digest) + for _, tc := range testCases { + d.Reset() + _, _ = d.Write(tc.input) + sum := d.checkSum() + hash := hex.EncodeToString(sum[:]) + if hash != tc.hash { + t.Errorf("expected (%s), got (%s)", tc.hash, hash) + } + } + }) + + t.Run("test AVX2 digest", func(t *testing.T) { + d := new(digest) + for _, tc := range testCases { + d.Reset() + _, _ = d.Write(tc.input) + sum := d.checkSum() + hash := hex.EncodeToString(sum[:]) + if hash != tc.hash { + t.Errorf("expected (%s), got (%s)", tc.hash, hash) + } + } + }) +} + +func TestHomomorphism(t *testing.T) { var ( c1, c2 sl2 n int @@ -36,7 +78,7 @@ func TestHash(t *testing.T) { require.Equal(t, h, c1.ByteArray()) } -var testCases = []struct { +var testCasesConcat = []struct { Hash string Parts []string }{{ @@ -62,7 +104,7 @@ func TestConcat(t *testing.T) { err error ) - for _, tc := range testCases { + for _, tc := range testCasesConcat { expect, err = hex.DecodeString(tc.Hash) require.NoError(t, err) @@ -86,7 +128,7 @@ func TestValidate(t *testing.T) { err error ) - for _, tc := range testCases { + for _, tc := range testCasesConcat { hash, _ = hex.DecodeString(tc.Hash) require.NoError(t, err)