diff --git a/README.md b/README.md index 60fe8b0..9d3b687 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ The example of how it works can be seen in tests. ## AVX vs AVX2 version ``` -BenchmarkAVX-8 500 3579980 ns/op 27.93 MB/s 64 B/op 4 allocs/op -BenchmarkAVX2-8 500 2997518 ns/op 33.36 MB/s 64 B/op 2 allocs/op +BenchmarkAVX-8 500 3492019 ns/op 28.64 MB/s 64 B/op 4 allocs/op +BenchmarkAVX2-8 500 2752693 ns/op 36.33 MB/s 64 B/op 2 allocs/op +BenchmarkAVX2Inline-8 1000 1877260 ns/op 53.27 MB/s 64 B/op 2 allocs/op ``` # Contributing diff --git a/tz/avx2_unroll_amd64.s b/tz/avx2_unroll_amd64.s new file mode 100644 index 0000000..d4b368c --- /dev/null +++ b/tz/avx2_unroll_amd64.s @@ -0,0 +1,186 @@ +#include "textflag.h" + +// func mulByteRightx2(c00c10, c01c11 *[4]uint64, b byte) +TEXT ·mulByteRightx2(SB),NOSPLIT,$0 + MOVQ c00c10+0(FP), AX + VMOVDQA (AX), Y0 + MOVQ c01c11+8(FP), BX + VMOVDQA (BX), Y8 + MOVB b+16(FP), CX + + // 1 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $7, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 2 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $6, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 3 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $5, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 4 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $4, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 5 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $3, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 6 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $2, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 7 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + SHRQ $1, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y3, Y0 + + // 8 bit + VPSLLQ $1, Y0, Y1 + VPALIGNR $8, Y1, Y0, Y2 + VPSRLQ $63, Y2, Y2 + VPXOR Y1, Y2, Y2 + VPSRLQ $63, Y1, Y3 + VPSLLQ $63, Y3, Y3 + VPUNPCKHQDQ Y3, Y3, Y3 + VPXOR Y2, Y3, Y3 + + MOVQ CX, DX + ANDQ $1, DX + NEGQ DX + MOVQ DX, X1 + VPBROADCASTB X1, Y2 + + VPXOR Y3, Y8, Y3 + VPAND Y3, Y2, Y4 + VPXOR Y4, Y0, Y8 + VMOVDQA Y8, (BX) + VMOVDQA Y3, (AX) + RET diff --git a/tz/hash_avx2.go b/tz/hash_avx2.go index 5d43516..d686f86 100644 --- a/tz/hash_avx2.go +++ b/tz/hash_avx2.go @@ -12,6 +12,12 @@ type digest2 struct { var _ hash.Hash = (*digest2)(nil) +func NewAVX2() hash.Hash { + d := new(digest2) + d.Reset() + return d +} + func (d *digest2) Write(data []byte) (n int, err error) { n = len(data) for _, b := range data { diff --git a/tz/hash_avx2_inline.go b/tz/hash_avx2_inline.go new file mode 100644 index 0000000..c2800f2 --- /dev/null +++ b/tz/hash_avx2_inline.go @@ -0,0 +1,54 @@ +package tz + +import ( + "hash" + + "github.com/nspcc-dev/tzhash/gf127" +) + +type digest3 struct { + x [2]gf127.GF127x2 +} + +var _ hash.Hash = (*digest3)(nil) + +func NewAVX2Inline() hash.Hash { + d := new(digest3) + d.Reset() + return d +} + +func (d *digest3) Write(data []byte) (n int, err error) { + n = len(data) + for _, b := range data { + mulByteRightx2(&d.x[0], &d.x[1], b) + } + return +} + +func (d *digest3) Sum(in []byte) []byte { + // Make a copy of d so that caller can keep writing and summing. + d0 := *d + h := d0.checkSum() + return append(in, h[:]...) +} +func (d *digest3) Reset() { + d.x[0] = gf127.GF127x2{1, 0, 0, 0} + d.x[1] = gf127.GF127x2{0, 0, 1, 0} +} +func (d *digest3) Size() int { return hashSize } +func (d *digest3) BlockSize() int { return hashBlockSize } +func (d *digest3) checkSum() (b [hashSize]byte) { + // Matrix is stored transposed, + // but we need to use order consistent with digest. + h := d.x[0].ByteArray() + copy(b[:], h[:16]) + copy(b[32:], h[16:]) + + h = d.x[1].ByteArray() + copy(b[16:], h[:16]) + copy(b[48:], h[16:]) + return +} + +func mulByteRightx2(c00c10 *gf127.GF127x2, c01c11 *gf127.GF127x2, b byte) diff --git a/tz/hash_test.go b/tz/hash_test.go index 650bfb5..750d1c3 100644 --- a/tz/hash_test.go +++ b/tz/hash_test.go @@ -51,6 +51,17 @@ func TestHash(t *testing.T) { require.Equal(t, tc.hash, hex.EncodeToString(sum[:])) } }) + + t.Run("test AVX2 digest with inline asm function", func(t *testing.T) { + d := new(digest3) + for _, tc := range testCases { + d.Reset() + _, _ = d.Write(tc.input) + sum := d.checkSum() + + require.Equal(t, tc.hash, hex.EncodeToString(sum[:])) + } + }) } func newBuffer() (data []byte) { @@ -92,6 +103,20 @@ func BenchmarkAVX2(b *testing.B) { b.SetBytes(int64(len(data))) } +func BenchmarkAVX2Inline(b *testing.B) { + data := newBuffer() + + b.ResetTimer() + b.ReportAllocs() + d := new(digest3) + for i := 0; i < b.N; i++ { + d.Reset() + _, _ = d.Write(data) + d.checkSum() + } + b.SetBytes(int64(len(data))) +} + func TestHomomorphism(t *testing.T) { var ( c1, c2 sl2