diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 086ab630b..b08795ef7 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1220,24 +1220,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if v.checkhash == nil { panic("VM is not set up properly for signature checks") } - sigok := true - // j counts keys and i counts signatures. - j := 0 - for i := 0; sigok && j < len(pkeys) && i < len(sigs); { - pkey := v.bytesToPublicKey(pkeys[j]) - // We only move to the next signature if the check was - // successful, but if it's not maybe the next key will - // fit, so we always move to the next key. - if pkey.Verify(sigs[i], v.checkhash) { - i++ - } - j++ - // When there are more signatures left to check than - // there are keys the check won't successed for sure. - if len(sigs)-i > len(pkeys)-j { - sigok = false - } - } + + sigok := checkMultisigPar(v, pkeys, sigs) v.estack.PushVal(sigok) case opcode.NEWMAP: @@ -1428,6 +1412,109 @@ func (v *VM) getJumpOffset(ctx *Context, parameter []byte) int { return offset } +func checkMultisigPar(v *VM, pkeys [][]byte, sigs [][]byte) bool { + if len(sigs) == 1 { + return checkMultisig1(v, pkeys, sigs[0]) + } + + k1, k2 := 0, len(pkeys)-1 + s1, s2 := 0, len(sigs)-1 + + type task struct { + pub *keys.PublicKey + signum int + } + + type verify struct { + ok bool + signum int + } + + worker := func(ch <-chan task, result chan verify) { + for { + t, ok := <-ch + if !ok { + return + } + + result <- verify{ + signum: t.signum, + ok: t.pub.Verify(sigs[t.signum], v.checkhash), + } + } + } + + const workerCount = 3 + tasks := make(chan task, 2) + results := make(chan verify, len(sigs)) + for i := 0; i < workerCount; i++ { + go worker(tasks, results) + } + + tasks <- task{pub: v.bytesToPublicKey(pkeys[k1]), signum: s1} + tasks <- task{pub: v.bytesToPublicKey(pkeys[k2]), signum: s2} + + sigok := true + taskCount := 2 + +loop: + for r := range results { + goingForward := true + + taskCount-- + if r.signum == s2 { + goingForward = false + } + if k1+1 == k2 { + sigok = r.ok && s1+1 == s2 + if taskCount != 0 && sigok { + continue + } + break loop + } else if r.ok { + if s1+1 == s2 { + if taskCount != 0 && sigok { + continue + } + break loop + } + if goingForward { + s1++ + } else { + s2-- + } + } + + var nextSig, nextKey int + if goingForward { + k1++ + nextSig = s1 + nextKey = k1 + } else { + k2-- + nextSig = s2 + nextKey = k2 + } + taskCount++ + tasks <- task{pub: v.bytesToPublicKey(pkeys[nextKey]), signum: nextSig} + } + + close(tasks) + + return sigok +} + +func checkMultisig1(v *VM, pkeys [][]byte, sig []byte) bool { + for i := range pkeys { + pkey := v.bytesToPublicKey(pkeys[i]) + if pkey.Verify(sig, v.checkhash) { + return true + } + } + + return false +} + func cloneIfStruct(item StackItem) StackItem { switch it := item.(type) { case *StructItem: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index f20993c46..ff0116157 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -2702,24 +2702,95 @@ func TestCHECKMULTISIGBadSig(t *testing.T) { assert.Equal(t, false, vm.estack.Pop().Bool()) } -func TestCHECKMULTISIGGood(t *testing.T) { +func initCHECKMULTISIG(msg []byte, n int) ([]StackItem, []StackItem, map[string]*keys.PublicKey, error) { + var err error + + keyMap := make(map[string]*keys.PublicKey) + pkeys := make([]*keys.PrivateKey, n) + pubs := make([]StackItem, n) + for i := range pubs { + pkeys[i], err = keys.NewPrivateKey() + if err != nil { + return nil, nil, nil, err + } + + pk := pkeys[i].PublicKey() + data := pk.Bytes() + pubs[i] = NewByteArrayItem(data) + keyMap[string(data)] = pk + } + + sigs := make([]StackItem, n) + for i := range sigs { + sig := pkeys[i].Sign(msg) + sigs[i] = NewByteArrayItem(sig) + } + + return pubs, sigs, keyMap, nil +} + +func subSlice(arr []StackItem, indices []int) []StackItem { + if indices == nil { + return arr + } + + result := make([]StackItem, len(indices)) + for i, j := range indices { + result[i] = arr[j] + } + + return result +} + +func initCHECKMULTISIGVM(t *testing.T, n int, ik, is []int) *VM { prog := makeProgram(opcode.CHECKMULTISIG) - pk1, err := keys.NewPrivateKey() - assert.Nil(t, err) - pk2, err := keys.NewPrivateKey() - assert.Nil(t, err) + v := load(prog) msg := []byte("NEO - An Open Network For Smart Economy") - sig1 := pk1.Sign(msg) - sig2 := pk2.Sign(msg) - pbytes1 := pk1.PublicKey().Bytes() - pbytes2 := pk2.PublicKey().Bytes() - vm := load(prog) - vm.SetCheckedHash(hash.Sha256(msg).BytesBE()) - vm.estack.PushVal([]StackItem{NewByteArrayItem(sig1), NewByteArrayItem(sig2)}) - vm.estack.PushVal([]StackItem{NewByteArrayItem(pbytes1), NewByteArrayItem(pbytes2)}) - runVM(t, vm) - assert.Equal(t, 1, vm.estack.Len()) - assert.Equal(t, true, vm.estack.Pop().Bool()) + + v.SetCheckedHash(hash.Sha256(msg).BytesBE()) + + pubs, sigs, _, err := initCHECKMULTISIG(msg, n) + require.NoError(t, err) + + pubs = subSlice(pubs, ik) + sigs = subSlice(sigs, is) + + v.estack.PushVal(sigs) + v.estack.PushVal(pubs) + + return v +} + +func testCHECKMULTISIGGood(t *testing.T, n int, is []int) { + v := initCHECKMULTISIGVM(t, n, nil, is) + + runVM(t, v) + assert.Equal(t, 1, v.estack.Len()) + assert.True(t, v.estack.Pop().Bool()) +} + +func TestCHECKMULTISIGGood(t *testing.T) { + t.Run("3_1", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{1}) }) + t.Run("2_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 2, []int{0, 1}) }) + t.Run("3_3", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{0, 1, 2}) }) + t.Run("3_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 3, []int{0, 2}) }) + t.Run("4_2", func(t *testing.T) { testCHECKMULTISIGGood(t, 4, []int{0, 2}) }) + t.Run("10_7", func(t *testing.T) { testCHECKMULTISIGGood(t, 10, []int{2, 3, 4, 5, 6, 8, 9}) }) + t.Run("12_9", func(t *testing.T) { testCHECKMULTISIGGood(t, 12, []int{0, 1, 4, 5, 6, 7, 8, 9}) }) +} + +func testCHECKMULTISIGBad(t *testing.T, n int, ik, is []int) { + v := initCHECKMULTISIGVM(t, n, ik, is) + + runVM(t, v) + assert.Equal(t, 1, v.estack.Len()) + assert.False(t, v.estack.Pop().Bool()) +} + +func TestCHECKMULTISIGBad(t *testing.T) { + t.Run("1_1 wrong signature", func(t *testing.T) { testCHECKMULTISIGBad(t, 2, []int{0}, []int{1}) }) + t.Run("3_2 wrong order", func(t *testing.T) { testCHECKMULTISIGBad(t, 3, []int{0, 2}, []int{2, 0}) }) + t.Run("3_2 duplicate sig", func(t *testing.T) { testCHECKMULTISIGBad(t, 3, nil, []int{0, 0}) }) } func TestSWAPGood(t *testing.T) {