Inline asm function in loop for AVX2 implementation

Right now the AVX2 implementation loses to the C binding in speed.
This is probably because of two things:
1. Go does not inline `mulBitRightx2` in loop iteration.
2. `minmax` is loaded every time from memory.

In this PR:
1. Unroll `mulBitRightx2` manually and use `mulByteRightx2` instead.
2. Generate `minmax` in place without `LOAD/LEA` instructions.
This commit is contained in:
Evgenii 2019-07-19 16:11:01 +03:00
parent 5c2544cf3b
commit c68e38b943
5 changed files with 274 additions and 2 deletions

View file

@ -31,8 +31,9 @@ The example of how it works can be seen in tests.
## AVX vs AVX2 version ## AVX vs AVX2 version
``` ```
BenchmarkAVX-8 500 3579980 ns/op 27.93 MB/s 64 B/op 4 allocs/op BenchmarkAVX-8 500 3492019 ns/op 28.64 MB/s 64 B/op 4 allocs/op
BenchmarkAVX2-8 500 2997518 ns/op 33.36 MB/s 64 B/op 2 allocs/op BenchmarkAVX2-8 500 2752693 ns/op 36.33 MB/s 64 B/op 2 allocs/op
BenchmarkAVX2Inline-8 1000 1877260 ns/op 53.27 MB/s 64 B/op 2 allocs/op
``` ```
# Contributing # Contributing

186
tz/avx2_unroll_amd64.s Normal file
View file

@ -0,0 +1,186 @@
#include "textflag.h"
// func mulByteRightx2(c00c10, c01c11 *[4]uint64, b byte)
//
// mulByteRightx2 updates the two 256-bit values pointed to by c00c10 and
// c01c11 in place, consuming the 8 bits of b from the most significant bit
// (shift 7) down to the least significant bit (no shift). The per-bit step is
// unrolled eight times instead of being a loop.
//
// Register use:
//   AX, BX  - the c00c10 and c01c11 pointers (kept for the final stores)
//   CX      - the byte b; DX - scratch for the current bit / mask
//   Y0      - working copy of *c00c10
//   Y8      - working copy of *c01c11
//   Y1-Y4   - scratch
//
// With S(v) denoting the "shift-and-fold" of v computed below, each round does
//   t  = S(Y0) ^ Y8
//   Y8 = Y0 ^ (t & mask)     // mask is all-ones if the current bit is 1, else 0
//   Y0 = t
//
// S(v) works independently on each 128-bit lane: the lane is shifted left by
// one bit (with the carry from the low quadword into the high quadword), and
// the bit that lands in position 127 is cleared and XORed into bits 0 and 63.
// NOTE(review): this matches multiplication by x modulo x^127 + x^63 + 1,
// presumably the polynomial used by the gf127 package - confirm there.
//
// NOTE(review): VMOVDQA is the aligned-move form, so both pointers presumably
// must be 32-byte aligned; confirm every caller guarantees that.
TEXT ·mulByteRightx2(SB),NOSPLIT,$0
// Load both operands and the input byte.
MOVQ c00c10+0(FP), AX
VMOVDQA (AX), Y0
MOVQ c01c11+8(FP), BX
VMOVDQA (BX), Y8
MOVB b+16(FP), CX
// Round 1 of 8: uses bit 7 of b.
// Y1 = each 64-bit quadword of Y0 shifted left by one.
VPSLLQ $1, Y0, Y1
// Per 128-bit lane, Y2 = {high quadword of Y1, low quadword of Y0}; after the
// shift right by 63 it holds {bit 127 of the shifted lane, carry out of the
// low quadword}, i.e. the bits to XOR into bit 0 of the low and high quadwords.
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
// Y3 = top bit of each quadword of Y1, kept at bit 63; the unpack then copies
// the high quadword's value into both quadwords of the lane.
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
// Y3 = S(Y0): clears bit 127 and toggles bit 63 when bit 127 was set.
VPXOR Y2, Y3, Y3
// DX = 0 if the current bit of b is clear, all ones (-1) if it is set.
MOVQ CX, DX
SHRQ $7, DX
ANDQ $1, DX
NEGQ DX
// Y2 = low byte of DX (0x00 or 0xFF) replicated into every byte.
MOVQ DX, X1
VPBROADCASTB X1, Y2
// t = S(Y0) ^ Y8; new Y8 = old Y0 ^ (t & mask); new Y0 = t.
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 2 of 8: same sequence as round 1, using bit 6 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $6, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 3 of 8: same sequence, using bit 5 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $5, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 4 of 8: same sequence, using bit 4 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $4, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 5 of 8: same sequence, using bit 3 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $3, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 6 of 8: same sequence, using bit 2 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $2, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 7 of 8: same sequence, using bit 1 of b.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
SHRQ $1, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
VMOVDQA Y3, Y0
// Round 8 of 8: uses bit 0 of b, so no SHRQ is needed before the mask.
VPSLLQ $1, Y0, Y1
VPALIGNR $8, Y1, Y0, Y2
VPSRLQ $63, Y2, Y2
VPXOR Y1, Y2, Y2
VPSRLQ $63, Y1, Y3
VPSLLQ $63, Y3, Y3
VPUNPCKHQDQ Y3, Y3, Y3
VPXOR Y2, Y3, Y3
MOVQ CX, DX
ANDQ $1, DX
NEGQ DX
MOVQ DX, X1
VPBROADCASTB X1, Y2
VPXOR Y3, Y8, Y3
VPAND Y3, Y2, Y4
VPXOR Y4, Y0, Y8
// Write the results back: Y8 to *c01c11, and Y3 (the value Y0 would have
// received in this last round) directly to *c00c10.
VMOVDQA Y8, (BX)
VMOVDQA Y3, (AX)
RET

View file

@ -12,6 +12,12 @@ type digest2 struct {
var _ hash.Hash = (*digest2)(nil) var _ hash.Hash = (*digest2)(nil)
// NewAVX2 returns a freshly reset digest2 as a hash.Hash.
func NewAVX2() hash.Hash {
	var d digest2
	d.Reset()
	return &d
}
func (d *digest2) Write(data []byte) (n int, err error) { func (d *digest2) Write(data []byte) (n int, err error) {
n = len(data) n = len(data)
for _, b := range data { for _, b := range data {

54
tz/hash_avx2_inline.go Normal file
View file

@ -0,0 +1,54 @@
package tz
import (
"hash"
"github.com/nspcc-dev/tzhash/gf127"
)
// digest3 is the hash state driven by the assembly routine mulByteRightx2.
// x[0] is passed to it as the c00c10 argument and x[1] as c01c11.
type digest3 struct {
x [2]gf127.GF127x2
}
// Compile-time check that *digest3 satisfies hash.Hash.
var _ hash.Hash = (*digest3)(nil)
// NewAVX2Inline returns a freshly reset digest3 as a hash.Hash.
func NewAVX2Inline() hash.Hash {
	var d digest3
	d.Reset()
	return &d
}
// Write feeds every byte of data, in order, into the state via
// mulByteRightx2. It always reports len(data) bytes written and a nil error.
func (d *digest3) Write(data []byte) (n int, err error) {
	for i := range data {
		mulByteRightx2(&d.x[0], &d.x[1], data[i])
	}
	return len(data), nil
}
// Sum appends the current checksum to in and returns the resulting slice.
func (d *digest3) Sum(in []byte) []byte {
	// Work on a snapshot so the caller can keep writing and summing on d.
	snapshot := *d
	sum := snapshot.checkSum()
	return append(in, sum[:]...)
}
// Reset restores the initial state: x[0] = {1, 0, 0, 0} and
// x[1] = {0, 0, 1, 0}.
func (d *digest3) Reset() {
	d.x = [2]gf127.GF127x2{
		{1, 0, 0, 0},
		{0, 0, 1, 0},
	}
}
// Size returns hashSize, the length in bytes of the checksum.
func (d *digest3) Size() int {
	return hashSize
}
// BlockSize returns hashBlockSize.
func (d *digest3) BlockSize() int {
	return hashBlockSize
}
// checkSum serializes the state into a hashSize-byte array.
//
// The matrix is stored transposed, so the two 16-byte halves of each x[i]
// are interleaved to keep the byte order consistent with digest:
// x[0] fills b[0:16] and b[32:48], x[1] fills b[16:32] and b[48:64].
func (d *digest3) checkSum() (b [hashSize]byte) {
	for i := range d.x {
		half := d.x[i].ByteArray()
		offset := 16 * i
		copy(b[offset:], half[:16])
		copy(b[offset+32:], half[16:])
	}
	return b
}
// mulByteRightx2 has no Go body; it is implemented in assembly
// (TEXT ·mulByteRightx2 in avx2_unroll_amd64.s). It updates *c00c10 and
// *c01c11 in place according to the bits of b.
func mulByteRightx2(c00c10 *gf127.GF127x2, c01c11 *gf127.GF127x2, b byte)

View file

@ -51,6 +51,17 @@ func TestHash(t *testing.T) {
require.Equal(t, tc.hash, hex.EncodeToString(sum[:])) require.Equal(t, tc.hash, hex.EncodeToString(sum[:]))
} }
}) })
t.Run("test AVX2 digest with inline asm function", func(t *testing.T) {
d := new(digest3)
for _, tc := range testCases {
d.Reset()
_, _ = d.Write(tc.input)
sum := d.checkSum()
require.Equal(t, tc.hash, hex.EncodeToString(sum[:]))
}
})
} }
func newBuffer() (data []byte) { func newBuffer() (data []byte) {
@ -92,6 +103,20 @@ func BenchmarkAVX2(b *testing.B) {
b.SetBytes(int64(len(data))) b.SetBytes(int64(len(data)))
} }
// BenchmarkAVX2Inline measures hashing of the buffer returned by newBuffer
// with digest3: each iteration resets the state, writes the whole buffer and
// computes the checksum.
func BenchmarkAVX2Inline(b *testing.B) {
	payload := newBuffer()

	b.ResetTimer()
	b.ReportAllocs()

	hasher := new(digest3)
	for iter := 0; iter < b.N; iter++ {
		hasher.Reset()
		// The Write results are deliberately discarded, as in the original.
		_, _ = hasher.Write(payload)
		hasher.checkSum()
	}

	// Report throughput in bytes per operation.
	b.SetBytes(int64(len(payload)))
}
func TestHomomorphism(t *testing.T) { func TestHomomorphism(t *testing.T) {
var ( var (
c1, c2 sl2 c1, c2 sl2