diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 086ab630b..9fffc4e15 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -84,6 +84,8 @@ type VM struct { // Public keys cache. keys map[string]*keys.PublicKey + + checkMultisig func(*VM, [][]byte, [][]byte) bool } // New returns a new VM object ready to load .avm bytecode scripts. @@ -1220,24 +1222,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if v.checkhash == nil { panic("VM is not set up properly for signature checks") } - sigok := true - // j counts keys and i counts signatures. - j := 0 - for i := 0; sigok && j < len(pkeys) && i < len(sigs); { - pkey := v.bytesToPublicKey(pkeys[j]) - // We only move to the next signature if the check was - // successful, but if it's not maybe the next key will - // fit, so we always move to the next key. - if pkey.Verify(sigs[i], v.checkhash) { - i++ - } - j++ - // When there are more signatures left to check than - // there are keys the check won't successed for sure. - if len(sigs)-i > len(pkeys)-j { - sigok = false - } + + var sigok bool + if v.checkMultisig == nil { + sigok = checkMultisigPar(v, pkeys, sigs) + } else { + sigok = v.checkMultisig(v, pkeys, sigs) } + v.estack.PushVal(sigok) case opcode.NEWMAP: @@ -1428,6 +1420,135 @@ func (v *VM) getJumpOffset(ctx *Context, parameter []byte) int { return offset } +func checkMultisigPar(v *VM, pkeys [][]byte, sigs [][]byte) bool { + if len(sigs) == 1 { + return checkMultisig1(v, pkeys, sigs[0]) + } + + k1, k2 := 0, len(pkeys)-1 + s1, s2 := 0, len(sigs)-1 + + type task struct { + pub *keys.PublicKey + signum int + } + + type verify struct { + ok bool + signum int + } + + worker := func(ch <-chan task, result chan verify) { + for { + t, ok := <-ch + if !ok { + return + } + + result <- verify{ + signum: t.signum, + ok: t.pub.Verify(sigs[t.signum], v.checkhash), + } + } + } + + const workerCount = 3 + tasks := make(chan task, 2) + results := make(chan verify, len(sigs)) + for i := 0; i < workerCount; i++ { + go worker(tasks, results) + } + + tasks <- task{pub: v.bytesToPublicKey(pkeys[k1]), signum: s1} + tasks <- task{pub: v.bytesToPublicKey(pkeys[k2]), signum: s2} + + sigok := true + taskCount := 2 + +loop: + for r := range results { + taskCount-- + if r.signum == s1 { + if k1+1 == k2 { + sigok = r.ok && s1+1 == s2 + if taskCount != 0 && sigok { + continue + } + break loop + } else if r.ok { + if s1+1 == s2 { + if taskCount != 0 && sigok { + continue + } + break loop + } + s1++ + } + k1++ + taskCount++ + tasks <- task{pub: v.bytesToPublicKey(pkeys[k1]), signum: s1} + } else { + if k1+1 == k2 { + sigok = r.ok && s1+1 == s2 + if taskCount != 0 && sigok { + continue + } + break loop + } else if r.ok { + if s1+1 == s2 { + if taskCount != 0 && sigok { + continue + } + break loop + } + s2-- + } + k2-- + taskCount++ + tasks <- task{pub: v.bytesToPublicKey(pkeys[k2]), signum: s2} + } + } + + close(tasks) + + return sigok +} + +func checkMultisig1(v *VM, pkeys [][]byte, sig []byte) bool { + for i := range pkeys { + pkey := v.bytesToPublicKey(pkeys[i]) + if pkey.Verify(sig, v.checkhash) { + return true + } + } + + return false +} + +func checkMultisigSeq(v *VM, pkeys [][]byte, sigs [][]byte) bool { + // j counts keys and i counts signatures. + j := 0 + for i := 0; j < len(pkeys) && i < len(sigs); { + pkey := v.bytesToPublicKey(pkeys[j]) + + // We only move to the next signature if the check was + // successful, but if it's not maybe the next key will + // fit, so we always move to the next key. + if pkey.Verify(sigs[i], v.checkhash) { + i++ + } + j++ + + // When there are more signatures left to check than + // there are keys the check won't successed for sure. + if len(sigs)-i > len(pkeys)-j { + return false + } + } + + return true +} + func cloneIfStruct(item StackItem) StackItem { switch it := item.(type) { case *StructItem: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index cab01864a..9b56ad9c7 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -2699,24 +2699,188 @@ func TestCHECKMULTISIGBadSig(t *testing.T) { assert.Equal(t, false, vm.estack.Pop().Bool()) } -func TestCHECKMULTISIGGood(t *testing.T) { +func initCHECKMULTISIG(msg []byte, n int) ([]StackItem, []StackItem, map[string]*keys.PublicKey, error) { + var err error + + keyMap := make(map[string]*keys.PublicKey) + pkeys := make([]*keys.PrivateKey, n) + pubs := make([]StackItem, n) + for i := range pubs { + pkeys[i], err = keys.NewPrivateKey() + if err != nil { + return nil, nil, nil, err + } + + pk := pkeys[i].PublicKey() + data := pk.Bytes() + pubs[i] = NewByteArrayItem(data) + keyMap[string(data)] = pk + } + + sigs := make([]StackItem, n) + for i := range sigs { + sig := pkeys[i].Sign(msg) + sigs[i] = NewByteArrayItem(sig) + } + + return pubs, sigs, keyMap, nil +} + +func subSlice(arr []StackItem, indices []int) []StackItem { + if indices == nil { + return arr + } + + result := make([]StackItem, len(indices)) + for i, j := range indices { + result[i] = arr[j] + } + + return result +} + +func initCHECKMULTISIGVM(t *testing.T, n int, ik, is []int) *VM { prog := makeProgram(opcode.CHECKMULTISIG) - pk1, err := keys.NewPrivateKey() - assert.Nil(t, err) - pk2, err := keys.NewPrivateKey() - assert.Nil(t, err) + v := load(prog) msg := []byte("NEO - An Open Network For Smart Economy") - sig1 := pk1.Sign(msg) - sig2 := pk2.Sign(msg) - pbytes1 := pk1.PublicKey().Bytes() - pbytes2 := pk2.PublicKey().Bytes() - vm := load(prog) - vm.SetCheckedHash(hash.Sha256(msg).BytesBE()) - vm.estack.PushVal([]StackItem{NewByteArrayItem(sig1), NewByteArrayItem(sig2)}) - vm.estack.PushVal([]StackItem{NewByteArrayItem(pbytes1), NewByteArrayItem(pbytes2)}) - runVM(t, vm) - assert.Equal(t, 1, vm.estack.Len()) - assert.Equal(t, true, vm.estack.Pop().Bool()) + + v.SetCheckedHash(hash.Sha256(msg).BytesBE()) + + pubs, sigs, _, err := initCHECKMULTISIG(msg, n) + require.NoError(t, err) + + pubs = subSlice(pubs, ik) + sigs = subSlice(sigs, is) + + v.estack.PushVal(sigs) + v.estack.PushVal(pubs) + + return v +} + +func testCHECKMULTISIGGood(t *testing.T, n int, is []int) { + v := initCHECKMULTISIGVM(t, n, nil, is) + + runVM(t, v) + assert.Equal(t, 1, v.estack.Len()) + assert.True(t, v.estack.Pop().Bool()) +} + +func TestCHECKMULTISIGGood(t *testing.T) { + t.Run("3_1", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{1}) }) + t.Run("2_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 2, []int{0, 1}) }) + t.Run("3_3", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{0, 1, 2}) }) + t.Run("3_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{0, 2}) }) + t.Run("4_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 4, []int{0, 2}) }) + t.Run("10_7", func(t *testing.T) { testCHECKMULTISIGGood(t, 10, []int{2, 3, 4, 5, 6, 8, 9}) }) + t.Run("12_9", func(t *testing.T) { testCHECKMULTISIGGood(t, 12, []int{0, 1, 4, 5, 6, 7, 8, 9}) }) +} + +func testCHECKMULTISIGBad(t *testing.T, n int, ik, is []int) { + v := initCHECKMULTISIGVM(t, n, ik, is) + + runVM(t, v) + assert.Equal(t, 1, v.estack.Len()) + assert.False(t, v.estack.Pop().Bool()) +} + +func TestCHECKMULTISIGBad(t *testing.T) { + t.Run("1_1 wrong signature", func(t *testing.T) { testCHECKMULTISIGBad(t, 2, []int{0}, []int{1}) }) + t.Run("3_2 wrong order", func(t *testing.T) { testCHECKMULTISIGBad(t, 3, []int{0, 2}, []int{2, 0}) }) + t.Run("3_2 duplicate sig", func(t *testing.T) { testCHECKMULTISIGBad(t, 3, nil, []int{0, 0}) }) +} + +func benchCHECKMULTISIG(b *testing.B, size int, is [][]int, check func(*VM, [][]byte, [][]byte) bool) { + prog := makeProgram(opcode.CHECKMULTISIG) + msg := []byte("NEO - An Open Network For Smart Economy") + h := hash.Sha256(msg).BytesBE() + + pubs, sigs, keyMap, err := initCHECKMULTISIG(msg, size) + if err != nil { + b.Fatalf("error on initialize: %v", err) + } + + sigsStack := make([][]StackItem, 0, len(is)) + for i := range is { + sigsStack = append(sigsStack, subSlice(sigs, is[i])) + } + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for j := range sigsStack { + vm := load(prog) + vm.SetCheckedHash(h) + vm.SetPublicKeys(keyMap) + vm.checkMultisig = check + vm.estack.PushVal(sigsStack[j]) + vm.estack.PushVal(pubs) + _ = vm.Run() + } + } +} + +func BenchmarkCHECKMULTISIG(b *testing.B) { + fs := []struct { + name string + f func(*VM, [][]byte, [][]byte) bool + }{ + {"seq", checkMultisigSeq}, + {"par", checkMultisigPar}, + } + + b.Run("4_3 start", func(b *testing.B) { + for i := range fs { + b.Run(fs[i].name, func(b *testing.B) { + benchCHECKMULTISIG(b, 4, [][]int{{0, 1, 2}}, fs[i].f) + }) + } + }) + + b.Run("4_3 all", func(b *testing.B) { + for i := range fs { + b.Run(fs[i].name, func(b *testing.B) { + benchCHECKMULTISIG(b, 4, combinations(4, 3), fs[i].f) + }) + } + }) + + b.Run("7_5 start", func(b *testing.B) { + for i := range fs { + b.Run(fs[i].name, func(b *testing.B) { + benchCHECKMULTISIG(b, 7, [][]int{{0, 1, 2, 3, 4}}, fs[i].f) + }) + } + }) + + b.Run("7_5 all", func(b *testing.B) { + for i := range fs { + b.Run(fs[i].name, func(b *testing.B) { + benchCHECKMULTISIG(b, 7, combinations(7, 5), fs[i].f) + }) + } + }) +} + +func combinations(n, k int) [][]int { + if n < k { + return nil + } else if k == 0 { + return [][]int{{}} + } else if n == k { + res := make([]int, k) + for i := 0; i < k; i++ { + res[i] = i + } + return [][]int{res} + } + + result := combinations(n-1, k) + for _, c := range combinations(n-1, k-1) { + result = append(result, append(c, n-1)) + } + + return result } func TestSWAPGood(t *testing.T) {