From fab8dfb9f8630d2ed2ddb9d4ed96e4c506d437ab Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Fri, 8 Jul 2022 17:28:29 +0300 Subject: [PATCH] vm: move State type into a package of its own It's used a lot in other places that need it, but don't need whole VM at the same time. --- cli/contract_test.go | 14 +-- cli/executor_test.go | 4 +- cli/multisig_test.go | 4 +- cli/query/query.go | 5 +- cli/query_test.go | 5 +- pkg/compiler/native_test.go | 9 +- pkg/core/blockchain.go | 5 +- pkg/core/blockchain_neotest_test.go | 8 +- pkg/core/interop/contract/call.go | 3 +- pkg/core/native/ledger.go | 4 +- pkg/core/native/native_test/ledger_test.go | 10 +- .../native/native_test/management_test.go | 4 +- pkg/core/state/notification_event.go | 8 +- pkg/core/state/notification_event_test.go | 16 +-- pkg/neotest/basic.go | 5 +- pkg/neotest/client.go | 3 +- pkg/rpc/client/rpc_test.go | 4 +- pkg/rpc/server/client_test.go | 8 +- pkg/rpc/server/server_test.go | 13 +-- pkg/vm/json_test.go | 7 +- pkg/vm/state_test.go | 97 ------------------- pkg/vm/vm.go | 71 +++++++------- pkg/vm/{ => vmstate}/state.go | 50 +++++----- pkg/vm/vmstate/state_test.go | 97 +++++++++++++++++++ 24 files changed, 234 insertions(+), 220 deletions(-) delete mode 100644 pkg/vm/state_test.go rename pkg/vm/{ => vmstate}/state.go (51%) create mode 100644 pkg/vm/vmstate/state_test.go diff --git a/cli/contract_test.go b/cli/contract_test.go index a45680a1b..1db9e312d 100644 --- a/cli/contract_test.go +++ b/cli/contract_test.go @@ -23,8 +23,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" @@ -354,7 +354,7 @@ func TestContractDeployWithData(t *testing.T) { res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State, res.FaultException) + require.Equal(t, vmstate.Halt.String(), res.State, res.FaultException) require.Len(t, res.Stack, 1) require.Equal(t, []byte{12}, res.Stack[0].Value()) @@ -366,7 +366,7 @@ func TestContractDeployWithData(t *testing.T) { res = new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State, res.FaultException) + require.Equal(t, vmstate.Halt.String(), res.State, res.FaultException) require.Len(t, res.Stack, 1) require.Equal(t, []byte("take_me_to_church"), res.Stack[0].Value()) } @@ -672,7 +672,7 @@ func TestComlileAndInvokeFunction(t *testing.T) { res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State, res.FaultException) + require.Equal(t, vmstate.Halt.String(), res.State, res.FaultException) require.Len(t, res.Stack, 1) require.Equal(t, []byte("on create|sub create"), res.Stack[0].Value()) @@ -821,7 +821,7 @@ func TestComlileAndInvokeFunction(t *testing.T) { e.Run(t, append(cmd, strconv.FormatInt(storage.FindKeysOnly, 10))...) res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State) + require.Equal(t, vmstate.Halt.String(), res.State) require.Len(t, res.Stack, 1) require.Equal(t, []stackitem.Item{ stackitem.Make("findkey1"), @@ -832,7 +832,7 @@ func TestComlileAndInvokeFunction(t *testing.T) { e.Run(t, append(cmd, strconv.FormatInt(storage.FindDefault, 10))...) res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State) + require.Equal(t, vmstate.Halt.String(), res.State) require.Len(t, res.Stack, 1) arr, ok := res.Stack[0].Value().([]stackitem.Item) @@ -883,7 +883,7 @@ func TestComlileAndInvokeFunction(t *testing.T) { res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State) + require.Equal(t, vmstate.Halt.String(), res.State) require.Len(t, res.Stack, 1) require.Equal(t, []byte("on update|sub update"), res.Stack[0].Value()) }) diff --git a/cli/executor_test.go b/cli/executor_test.go index a61b7f0a2..6f69a6881 100644 --- a/cli/executor_test.go +++ b/cli/executor_test.go @@ -23,7 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/server" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" "github.com/urfave/cli" "go.uber.org/zap" @@ -293,7 +293,7 @@ func (e *executor) checkTxPersisted(t *testing.T, prefix ...string) (*transactio aer, err := e.Chain.GetAppExecResults(tx.Hash(), trigger.Application) require.NoError(t, err) require.Equal(t, 1, len(aer)) - require.Equal(t, vm.HaltState, aer[0].VMState) + require.Equal(t, vmstate.Halt, aer[0].VMState) return tx, height } diff --git a/cli/multisig_test.go b/cli/multisig_test.go index 654f131ce..c203a5d4c 100644 --- a/cli/multisig_test.go +++ b/cli/multisig_test.go @@ -15,7 +15,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -154,7 +154,7 @@ func TestSignMultisigTx(t *testing.T) { e.checkTxTestInvokeOutput(t, 11) res := new(result.Invoke) require.NoError(t, json.Unmarshal(e.Out.Bytes(), res)) - require.Equal(t, vm.HaltState.String(), res.State, res.FaultException) + require.Equal(t, vmstate.Halt.String(), res.State, res.FaultException) }) e.In.WriteString("pass\r") diff --git a/cli/query/query.go b/cli/query/query.go index 2bd29dcc9..6dd387ac6 100644 --- a/cli/query/query.go +++ b/cli/query/query.go @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/urfave/cli" ) @@ -129,7 +130,7 @@ func DumpApplicationLog( if len(res.Executions) != 1 { _, _ = tw.Write([]byte("Success:\tunknown (no execution data)\n")) } else { - _, _ = tw.Write([]byte(fmt.Sprintf("Success:\t%t\n", res.Executions[0].VMState == vm.HaltState))) + _, _ = tw.Write([]byte(fmt.Sprintf("Success:\t%t\n", res.Executions[0].VMState == vmstate.Halt))) } } if verbose { @@ -146,7 +147,7 @@ func DumpApplicationLog( v.PrintOps(tw) if res != nil { for _, e := range res.Executions { - if e.VMState != vm.HaltState { + if e.VMState != vmstate.Halt { _, _ = tw.Write([]byte("Exception:\t" + e.FaultException + "\n")) } } diff --git a/cli/query_test.go b/cli/query_test.go index 02a29e36d..850c0a480 100644 --- a/cli/query_test.go +++ b/cli/query_test.go @@ -16,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -117,7 +118,7 @@ func (e *executor) compareQueryTxVerbose(t *testing.T, tx *transaction.Transacti e.checkNextLine(t, `BlockHash:\s+`+e.Chain.GetHeaderHash(int(height)).StringLE()) res, _ := e.Chain.GetAppExecResults(tx.Hash(), trigger.Application) - e.checkNextLine(t, fmt.Sprintf(`Success:\s+%t`, res[0].Execution.VMState == vm.HaltState)) + e.checkNextLine(t, fmt.Sprintf(`Success:\s+%t`, res[0].Execution.VMState == vmstate.Halt)) for _, s := range tx.Signers { e.checkNextLine(t, fmt.Sprintf(`Signer:\s+%s\s*\(%s\)`, address.Uint160ToString(s.Account), s.Scopes.String())) } @@ -132,7 +133,7 @@ func (e *executor) compareQueryTxVerbose(t *testing.T, tx *transaction.Transacti } e.checkScriptDump(t, n) - if res[0].Execution.VMState != vm.HaltState { + if res[0].Execution.VMState != vmstate.Halt { e.checkNextLine(t, `Exception:\s+`+regexp.QuoteMeta(res[0].Execution.FaultException)) } e.checkEOF(t) diff --git a/pkg/compiler/native_test.go b/pkg/compiler/native_test.go index a486e49c0..6b8027205 100644 --- a/pkg/compiler/native_test.go +++ b/pkg/compiler/native_test.go @@ -29,6 +29,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -116,10 +117,10 @@ func TestLedgerTransactionWitnessCondition(t *testing.T) { } func TestLedgerVMStates(t *testing.T) { - require.EqualValues(t, ledger.NoneState, vm.NoneState) - require.EqualValues(t, ledger.HaltState, vm.HaltState) - require.EqualValues(t, ledger.FaultState, vm.FaultState) - require.EqualValues(t, ledger.BreakState, vm.BreakState) + require.EqualValues(t, ledger.NoneState, vmstate.None) + require.EqualValues(t, ledger.HaltState, vmstate.Halt) + require.EqualValues(t, ledger.FaultState, vmstate.Fault) + require.EqualValues(t, ledger.BreakState, vmstate.Break) } type nativeTestCase struct { diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index c0e386c8c..e1ddea8ac 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -40,6 +40,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "go.uber.org/zap" ) @@ -816,7 +817,7 @@ func (bc *Blockchain) notificationDispatcher() { for ch := range executionFeed { ch <- aer } - if aer.VMState == vm.HaltState { + if aer.VMState == vmstate.Halt { for i := range aer.Events { for ch := range notificationFeed { ch <- &subscriptions.NotificationEvent{ @@ -1065,7 +1066,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error err = fmt.Errorf("failed to store exec result: %w", err) break } - if aer.Execution.VMState == vm.HaltState { + if aer.Execution.VMState == vmstate.Halt { for j := range aer.Execution.Events { bc.handleNotification(&aer.Execution.Events[j], kvcache, transCache, block, aer.Container) } diff --git a/pkg/core/blockchain_neotest_test.go b/pkg/core/blockchain_neotest_test.go index 50788ed90..e53a12d94 100644 --- a/pkg/core/blockchain_neotest_test.go +++ b/pkg/core/blockchain_neotest_test.go @@ -39,10 +39,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -826,7 +826,7 @@ func TestBlockchain_Subscriptions(t *testing.T) { exec := <-executionCh require.Equal(t, b.Hash(), exec.Container) - require.Equal(t, exec.VMState, vm.HaltState) + require.Equal(t, exec.VMState, vmstate.Halt) // 3 burn events for every tx and 1 mint for primary node require.True(t, len(notificationCh) >= 4) @@ -841,7 +841,7 @@ func TestBlockchain_Subscriptions(t *testing.T) { require.Equal(t, txExpected, tx) exec := <-executionCh require.Equal(t, tx.Hash(), exec.Container) - if exec.VMState == vm.HaltState { + if exec.VMState == vmstate.Halt { notif := <-notificationCh require.Equal(t, hash.Hash160(tx.Script), notif.ScriptHash) } @@ -855,7 +855,7 @@ func TestBlockchain_Subscriptions(t *testing.T) { exec = <-executionCh require.Equal(t, b.Hash(), exec.Container) - require.Equal(t, exec.VMState, vm.HaltState) + require.Equal(t, exec.VMState, vmstate.Halt) bc.UnsubscribeFromBlocks(blockCh) bc.UnsubscribeFromTransactions(txCh) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 7bed15e45..3f47fd5ca 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -14,7 +14,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -168,7 +167,7 @@ func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract return fmt.Errorf("%w: %v", ErrNativeCall, err) } } - if ic.VM.State() == vm.FaultState { + if ic.VM.HasFailed() { return ErrNativeCall } return nil diff --git a/pkg/core/native/ledger.go b/pkg/core/native/ledger.go index e75f476eb..708bc7e6d 100644 --- a/pkg/core/native/ledger.go +++ b/pkg/core/native/ledger.go @@ -13,8 +13,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" ) // Ledger provides an interface to blocks/transactions storage for smart @@ -169,7 +169,7 @@ func (l *Ledger) getTransactionVMState(ic *interop.Context, params []stackitem.I } h, _, aer, err := ic.DAO.GetTxExecResult(hash) if err != nil || !isTraceableBlock(ic, h) { - return stackitem.Make(vm.NoneState) + return stackitem.Make(vmstate.None) } return stackitem.Make(aer.VMState) } diff --git a/pkg/core/native/native_test/ledger_test.go b/pkg/core/native/native_test/ledger_test.go index dbdd3a7d5..83455f4c5 100644 --- a/pkg/core/native/native_test/ledger_test.go +++ b/pkg/core/native/native_test/ledger_test.go @@ -13,9 +13,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/neotest" "github.com/nspcc-dev/neo-go/pkg/neotest/chain" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -56,22 +56,22 @@ func TestLedger_GetTransactionState(t *testing.T) { hash := e.InvokeScript(t, []byte{byte(opcode.RET)}, []neotest.Signer{c.Committee}) t.Run("unknown transaction", func(t *testing.T) { - ledgerInvoker.Invoke(t, vm.NoneState, "getTransactionVMState", util.Uint256{1, 2, 3}) + ledgerInvoker.Invoke(t, vmstate.None, "getTransactionVMState", util.Uint256{1, 2, 3}) }) t.Run("not a hash", func(t *testing.T) { ledgerInvoker.InvokeFail(t, "expected []byte of size 32", "getTransactionVMState", []byte{1, 2, 3}) }) t.Run("good: HALT", func(t *testing.T) { - ledgerInvoker.Invoke(t, vm.HaltState, "getTransactionVMState", hash) + ledgerInvoker.Invoke(t, vmstate.Halt, "getTransactionVMState", hash) }) t.Run("isn't traceable", func(t *testing.T) { // Add more blocks so that tx becomes untraceable. e.GenerateNewBlocks(t, int(e.Chain.GetConfig().MaxTraceableBlocks)) - ledgerInvoker.Invoke(t, vm.NoneState, "getTransactionVMState", hash) + ledgerInvoker.Invoke(t, vmstate.None, "getTransactionVMState", hash) }) t.Run("good: FAULT", func(t *testing.T) { faultedH := e.InvokeScript(t, []byte{byte(opcode.ABORT)}, []neotest.Signer{c.Committee}) - ledgerInvoker.Invoke(t, vm.FaultState, "getTransactionVMState", faultedH) + ledgerInvoker.Invoke(t, vmstate.Fault, "getTransactionVMState", faultedH) }) } diff --git a/pkg/core/native/native_test/management_test.go b/pkg/core/native/native_test/management_test.go index ffbc9f37a..50d1b894d 100644 --- a/pkg/core/native/native_test/management_test.go +++ b/pkg/core/native/native_test/management_test.go @@ -21,10 +21,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -69,7 +69,7 @@ func TestManagement_ContractCache(t *testing.T) { managementInvoker.CheckHalt(t, tx1.Hash()) aer, err := managementInvoker.Chain.GetAppExecResults(tx2.Hash(), trigger.Application) require.NoError(t, err) - require.Equal(t, vm.HaltState, aer[0].VMState, aer[0].FaultException) + require.Equal(t, vmstate.Halt, aer[0].VMState, aer[0].FaultException) require.NotEqual(t, stackitem.Null{}, aer[0].Stack) } diff --git a/pkg/core/state/notification_event.go b/pkg/core/state/notification_event.go index d87bb40e0..317df0808 100644 --- a/pkg/core/state/notification_event.go +++ b/pkg/core/state/notification_event.go @@ -8,8 +8,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" ) // NotificationEvent is a tuple of the scripthash that has emitted the Item as a @@ -94,7 +94,7 @@ func (aer *AppExecResult) EncodeBinaryWithContext(w *io.BinWriter, sc *stackitem func (aer *AppExecResult) DecodeBinary(r *io.BinReader) { r.ReadBytes(aer.Container[:]) aer.Trigger = trigger.Type(r.ReadB()) - aer.VMState = vm.State(r.ReadB()) + aer.VMState = vmstate.State(r.ReadB()) aer.GasConsumed = int64(r.ReadU64LE()) sz := r.ReadVarUint() if stackitem.MaxDeserialized < sz && r.Err == nil { @@ -197,7 +197,7 @@ func (aer *AppExecResult) UnmarshalJSON(data []byte) error { // all resulting notifications, state, stack and other metadata. type Execution struct { Trigger trigger.Type - VMState vm.State + VMState vmstate.State GasConsumed int64 Stack []stackitem.Item Events []NotificationEvent @@ -266,7 +266,7 @@ func (e *Execution) UnmarshalJSON(data []byte) error { return err } e.Trigger = trigger - state, err := vm.StateFromString(aux.VMState) + state, err := vmstate.FromString(aux.VMState) if err != nil { return err } diff --git a/pkg/core/state/notification_event_test.go b/pkg/core/state/notification_event_test.go index 896e2b82a..0a4a11082 100644 --- a/pkg/core/state/notification_event_test.go +++ b/pkg/core/state/notification_event_test.go @@ -8,8 +8,8 @@ import ( "github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -18,7 +18,7 @@ func BenchmarkAppExecResult_EncodeBinary(b *testing.B) { Container: random.Uint256(), Execution: Execution{ Trigger: trigger.Application, - VMState: vm.HaltState, + VMState: vmstate.Halt, GasConsumed: 12345, Stack: []stackitem.Item{}, Events: []NotificationEvent{{ @@ -54,7 +54,7 @@ func TestEncodeDecodeAppExecResult(t *testing.T) { Container: random.Uint256(), Execution: Execution{ Trigger: 1, - VMState: vm.HaltState, + VMState: vmstate.Halt, GasConsumed: 10, Stack: []stackitem.Item{stackitem.NewBool(true)}, Events: []NotificationEvent{}, @@ -63,12 +63,12 @@ func TestEncodeDecodeAppExecResult(t *testing.T) { } t.Run("halt", func(t *testing.T) { appExecResult := newAer() - appExecResult.VMState = vm.HaltState + appExecResult.VMState = vmstate.Halt testserdes.EncodeDecodeBinary(t, appExecResult, new(AppExecResult)) }) t.Run("fault", func(t *testing.T) { appExecResult := newAer() - appExecResult.VMState = vm.FaultState + appExecResult.VMState = vmstate.Fault testserdes.EncodeDecodeBinary(t, appExecResult, new(AppExecResult)) }) t.Run("with interop", func(t *testing.T) { @@ -150,7 +150,7 @@ func TestMarshalUnmarshalJSONAppExecResult(t *testing.T) { Container: random.Uint256(), Execution: Execution{ Trigger: trigger.Application, - VMState: vm.HaltState, + VMState: vmstate.Halt, GasConsumed: 10, Stack: []stackitem.Item{}, Events: []NotificationEvent{}, @@ -164,7 +164,7 @@ func TestMarshalUnmarshalJSONAppExecResult(t *testing.T) { Container: random.Uint256(), Execution: Execution{ Trigger: trigger.Application, - VMState: vm.FaultState, + VMState: vmstate.Fault, GasConsumed: 10, Stack: []stackitem.Item{stackitem.NewBool(true)}, Events: []NotificationEvent{}, @@ -178,7 +178,7 @@ func TestMarshalUnmarshalJSONAppExecResult(t *testing.T) { Container: random.Uint256(), Execution: Execution{ Trigger: trigger.OnPersist, - VMState: vm.HaltState, + VMState: vmstate.Halt, GasConsumed: 10, Stack: []stackitem.Item{}, Events: []NotificationEvent{}, diff --git a/pkg/neotest/basic.go b/pkg/neotest/basic.go index dd9077887..1a5879ceb 100644 --- a/pkg/neotest/basic.go +++ b/pkg/neotest/basic.go @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -210,7 +211,7 @@ func (e *Executor) InvokeScriptCheckFAULT(t testing.TB, script []byte, signers [ func (e *Executor) CheckHalt(t testing.TB, h util.Uint256, stack ...stackitem.Item) *state.AppExecResult { aer, err := e.Chain.GetAppExecResults(h, trigger.Application) require.NoError(t, err) - require.Equal(t, vm.HaltState, aer[0].VMState, aer[0].FaultException) + require.Equal(t, vmstate.Halt, aer[0].VMState, aer[0].FaultException) if len(stack) != 0 { require.Equal(t, stack, aer[0].Stack) } @@ -222,7 +223,7 @@ func (e *Executor) CheckHalt(t testing.TB, h util.Uint256, stack ...stackitem.It func (e *Executor) CheckFault(t testing.TB, h util.Uint256, s string) { aer, err := e.Chain.GetAppExecResults(h, trigger.Application) require.NoError(t, err) - require.Equal(t, vm.FaultState, aer[0].VMState) + require.Equal(t, vmstate.Fault, aer[0].VMState) require.True(t, strings.Contains(aer[0].FaultException, s), "expected: %s, got: %s", s, aer[0].FaultException) } diff --git a/pkg/neotest/client.go b/pkg/neotest/client.go index 513780d7f..b9d676f2b 100644 --- a/pkg/neotest/client.go +++ b/pkg/neotest/client.go @@ -9,6 +9,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -91,7 +92,7 @@ func (c *ContractInvoker) InvokeAndCheck(t testing.TB, checkResult func(t testin c.AddNewBlock(t, tx) aer, err := c.Chain.GetAppExecResults(tx.Hash(), trigger.Application) require.NoError(t, err) - require.Equal(t, vm.HaltState, aer[0].VMState, aer[0].FaultException) + require.Equal(t, vmstate.Halt, aer[0].VMState, aer[0].FaultException) if checkResult != nil { checkResult(t, aer[0].Stack) } diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 043e7a83f..c50c2b271 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -34,9 +34,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -129,7 +129,7 @@ var rpcClientTestCases = map[string][]rpcClientTestCase{ Executions: []state.Execution{ { Trigger: trigger.Application, - VMState: vm.HaltState, + VMState: vmstate.Halt, GasConsumed: 1, Stack: []stackitem.Item{stackitem.NewBigInteger(big.NewInt(1))}, Events: []state.NotificationEvent{}, diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index b0be1e62b..40df78ee8 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -35,10 +35,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -648,7 +648,7 @@ func TestSignAndPushP2PNotaryRequest(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, len(appLogs)) appLog := appLogs[0] - require.Equal(t, vm.HaltState, appLog.VMState) + require.Equal(t, vmstate.Halt, appLog.VMState) require.Equal(t, appLog.GasConsumed, req.FallbackTransaction.SystemFee) }) } @@ -1282,7 +1282,7 @@ func TestClient_InvokeAndPackIteratorResults(t *testing.T) { t.Run("default max items constraint", func(t *testing.T) { res, err := c.InvokeAndPackIteratorResults(storageHash, "iterateOverValues", []smartcontract.Parameter{}, nil) require.NoError(t, err) - require.Equal(t, vm.HaltState.String(), res.State) + require.Equal(t, vmstate.Halt.String(), res.State) require.Equal(t, 1, len(res.Stack)) require.Equal(t, stackitem.ArrayT, res.Stack[0].Type()) arr, ok := res.Stack[0].Value().([]stackitem.Item) @@ -1298,7 +1298,7 @@ func TestClient_InvokeAndPackIteratorResults(t *testing.T) { max := 123 res, err := c.InvokeAndPackIteratorResults(storageHash, "iterateOverValues", []smartcontract.Parameter{}, nil, max) require.NoError(t, err) - require.Equal(t, vm.HaltState.String(), res.State) + require.Equal(t, vmstate.Halt.String(), res.State) require.Equal(t, 1, len(res.Stack)) require.Equal(t, stackitem.ArrayT, res.Stack[0].Type()) arr, ok := res.Stack[0].Value().([]stackitem.Item) diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 6e80c8eee..412c1f692 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -45,6 +45,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -102,7 +103,7 @@ var rpcTestCases = map[string][]rpcTestCase{ assert.Equal(t, 1, len(res.Executions)) assert.Equal(t, expectedTxHash, res.Container) assert.Equal(t, trigger.Application, res.Executions[0].Trigger) - assert.Equal(t, vm.HaltState, res.Executions[0].VMState) + assert.Equal(t, vmstate.Halt, res.Executions[0].VMState) }, }, { @@ -116,7 +117,7 @@ var rpcTestCases = map[string][]rpcTestCase{ assert.Equal(t, 2, len(res.Executions)) assert.Equal(t, trigger.OnPersist, res.Executions[0].Trigger) assert.Equal(t, trigger.PostPersist, res.Executions[1].Trigger) - assert.Equal(t, vm.HaltState, res.Executions[0].VMState) + assert.Equal(t, vmstate.Halt, res.Executions[0].VMState) }, }, { @@ -129,7 +130,7 @@ var rpcTestCases = map[string][]rpcTestCase{ assert.Equal(t, genesisBlockHash, res.Container.StringLE()) assert.Equal(t, 1, len(res.Executions)) assert.Equal(t, trigger.PostPersist, res.Executions[0].Trigger) - assert.Equal(t, vm.HaltState, res.Executions[0].VMState) + assert.Equal(t, vmstate.Halt, res.Executions[0].VMState) }, }, { @@ -142,7 +143,7 @@ var rpcTestCases = map[string][]rpcTestCase{ assert.Equal(t, genesisBlockHash, res.Container.StringLE()) assert.Equal(t, 1, len(res.Executions)) assert.Equal(t, trigger.OnPersist, res.Executions[0].Trigger) - assert.Equal(t, vm.HaltState, res.Executions[0].VMState) + assert.Equal(t, vmstate.Halt, res.Executions[0].VMState) }, }, { @@ -1966,9 +1967,9 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] require.NoError(t, json.Unmarshal(data, &res)) require.Equal(t, 2, len(res.Executions)) require.Equal(t, trigger.OnPersist, res.Executions[0].Trigger) - require.Equal(t, vm.HaltState, res.Executions[0].VMState) + require.Equal(t, vmstate.Halt, res.Executions[0].VMState) require.Equal(t, trigger.PostPersist, res.Executions[1].Trigger) - require.Equal(t, vm.HaltState, res.Executions[1].VMState) + require.Equal(t, vmstate.Halt, res.Executions[1].VMState) }) t.Run("submit", func(t *testing.T) { diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 5d7050dbb..86a68540d 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/stretchr/testify/require" ) @@ -44,7 +45,7 @@ type ( } vmUTExecutionEngineState struct { - State State `json:"state"` + State vmstate.State `json:"state"` ResultStack []vmUTStackItem `json:"resultStack"` InvocationStack []vmUTExecutionContextState `json:"invocationStack"` } @@ -152,14 +153,14 @@ func testFile(t *testing.T, filename string) { t.Run(ut.Tests[i].Name, func(t *testing.T) { prog := []byte(test.Script) vm := load(prog) - vm.state = BreakState + vm.state = vmstate.Break vm.SyscallHandler = testSyscallHandler for i := range test.Steps { execStep(t, vm, test.Steps[i]) result := test.Steps[i].Result require.Equal(t, result.State, vm.state) - if result.State == FaultState { // do not compare stacks on fault + if result.State == vmstate.Fault { // do not compare stacks on fault continue } diff --git a/pkg/vm/state_test.go b/pkg/vm/state_test.go deleted file mode 100644 index 7c4d87015..000000000 --- a/pkg/vm/state_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package vm - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestStateFromString(t *testing.T) { - var ( - s State - err error - ) - - s, err = StateFromString("HALT") - assert.NoError(t, err) - assert.Equal(t, HaltState, s) - - s, err = StateFromString("BREAK") - assert.NoError(t, err) - assert.Equal(t, BreakState, s) - - s, err = StateFromString("FAULT") - assert.NoError(t, err) - assert.Equal(t, FaultState, s) - - s, err = StateFromString("NONE") - assert.NoError(t, err) - assert.Equal(t, NoneState, s) - - s, err = StateFromString("HALT, BREAK") - assert.NoError(t, err) - assert.Equal(t, HaltState|BreakState, s) - - s, err = StateFromString("FAULT, BREAK") - assert.NoError(t, err) - assert.Equal(t, FaultState|BreakState, s) - - _, err = StateFromString("HALT, KEK") - assert.Error(t, err) -} - -func TestState_HasFlag(t *testing.T) { - assert.True(t, HaltState.HasFlag(HaltState)) - assert.True(t, BreakState.HasFlag(BreakState)) - assert.True(t, FaultState.HasFlag(FaultState)) - assert.True(t, (HaltState | BreakState).HasFlag(HaltState)) - assert.True(t, (HaltState | BreakState).HasFlag(BreakState)) - - assert.False(t, HaltState.HasFlag(BreakState)) - assert.False(t, NoneState.HasFlag(HaltState)) - assert.False(t, (FaultState | BreakState).HasFlag(HaltState)) -} - -func TestState_MarshalJSON(t *testing.T) { - var ( - data []byte - err error - ) - - data, err = json.Marshal(HaltState | BreakState) - assert.NoError(t, err) - assert.Equal(t, data, []byte(`"HALT, BREAK"`)) - - data, err = json.Marshal(FaultState) - assert.NoError(t, err) - assert.Equal(t, data, []byte(`"FAULT"`)) -} - -func TestState_UnmarshalJSON(t *testing.T) { - var ( - s State - err error - ) - - err = json.Unmarshal([]byte(`"HALT, BREAK"`), &s) - assert.NoError(t, err) - assert.Equal(t, HaltState|BreakState, s) - - err = json.Unmarshal([]byte(`"FAULT, BREAK"`), &s) - assert.NoError(t, err) - assert.Equal(t, FaultState|BreakState, s) - - err = json.Unmarshal([]byte(`"NONE"`), &s) - assert.NoError(t, err) - assert.Equal(t, NoneState, s) -} - -// TestState_EnumCompat tests that byte value of State matches the C#'s one got from -// https://github.com/neo-project/neo-vm/blob/0028d862e253bda3c12eb8bb007a2d95822d3922/src/neo-vm/VMState.cs#L16. -func TestState_EnumCompat(t *testing.T) { - assert.Equal(t, byte(0), byte(NoneState)) - assert.Equal(t, byte(1<<0), byte(HaltState)) - assert.Equal(t, byte(1<<1), byte(FaultState)) - assert.Equal(t, byte(1<<2), byte(BreakState)) -} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 95e03a2c1..d29b2f9fc 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -23,6 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" ) type errorAtInstruct struct { @@ -62,7 +63,7 @@ type SyscallHandler = func(*VM, uint32) error // VM represents the virtual machine. type VM struct { - state State + state vmstate.State // callback to get interop price getPrice func(opcode.Opcode, []byte) int64 @@ -104,7 +105,7 @@ func New() *VM { // NewWithTrigger returns a new VM for executions triggered by t. func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ - state: NoneState, + state: vmstate.None, trigger: t, SyscallHandler: defaultSyscallHandler, @@ -126,7 +127,7 @@ func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) { // more efficient. It reuses invocation and evaluation stacks as well as VM structure // itself. func (v *VM) Reset(t trigger.Type) { - v.state = NoneState + v.state = vmstate.None v.getPrice = nil v.istack.elems = v.istack.elems[:0] v.estack.elems = v.estack.elems[:0] @@ -288,7 +289,7 @@ func (v *VM) LoadWithFlags(prog []byte, f callflag.CallFlag) { // Clear all stacks and state, it could be a reload. v.istack.Clear() v.estack.Clear() - v.state = NoneState + v.state = vmstate.None v.gasConsumed = 0 v.invTree = nil v.LoadScriptWithFlags(prog, f) @@ -398,7 +399,7 @@ func dumpStack(s *Stack) string { } // State returns the state for the VM. -func (v *VM) State() State { +func (v *VM) State() vmstate.State { return v.state } @@ -413,39 +414,39 @@ func (v *VM) Run() error { var ctx *Context if !v.Ready() { - v.state = FaultState + v.state = vmstate.Fault return errors.New("no program loaded") } - if v.state.HasFlag(FaultState) { + if v.state.HasFlag(vmstate.Fault) { // VM already ran something and failed, in general its state is // undefined in this case so we can't run anything. return errors.New("VM has failed") } - // HaltState (the default) or BreakState are safe to continue. - v.state = NoneState + // vmstate.Halt (the default) or vmstate.Break are safe to continue. + v.state = vmstate.None ctx = v.Context() for { switch { - case v.state.HasFlag(FaultState): + case v.state.HasFlag(vmstate.Fault): // Should be caught and reported already by the v.Step(), // but we're checking here anyway just in case. return errors.New("VM has failed") - case v.state.HasFlag(HaltState), v.state.HasFlag(BreakState): + case v.state.HasFlag(vmstate.Halt), v.state.HasFlag(vmstate.Break): // Normal exit from this loop. return nil - case v.state == NoneState: + case v.state == vmstate.None: if err := v.step(ctx); err != nil { return err } default: - v.state = FaultState + v.state = vmstate.Fault return errors.New("unknown state") } // check for breakpoint before executing the next instruction ctx = v.Context() if ctx != nil && ctx.atBreakPoint() { - v.state = BreakState + v.state = vmstate.Break } } } @@ -460,7 +461,7 @@ func (v *VM) Step() error { func (v *VM) step(ctx *Context) error { op, param, err := ctx.Next() if err != nil { - v.state = FaultState + v.state = vmstate.Fault return newError(ctx.ip, op, err) } return v.execute(ctx, op, param) @@ -472,7 +473,7 @@ func (v *VM) StepInto() error { ctx := v.Context() if ctx == nil { - v.state = HaltState + v.state = vmstate.Halt } if v.HasStopped() { @@ -482,7 +483,7 @@ func (v *VM) StepInto() error { if ctx != nil && ctx.prog != nil { op, param, err := ctx.Next() if err != nil { - v.state = FaultState + v.state = vmstate.Fault return newError(ctx.ip, op, err) } vErr := v.execute(ctx, op, param) @@ -493,7 +494,7 @@ func (v *VM) StepInto() error { cctx := v.Context() if cctx != nil && cctx.atBreakPoint() { - v.state = BreakState + v.state = vmstate.Break } return nil } @@ -501,16 +502,16 @@ func (v *VM) StepInto() error { // StepOut takes the debugger to the line where the current function was called. func (v *VM) StepOut() error { var err error - if v.state == BreakState { - v.state = NoneState + if v.state == vmstate.Break { + v.state = vmstate.None } expSize := v.istack.Len() - for v.state == NoneState && v.istack.Len() >= expSize { + for v.state == vmstate.None && v.istack.Len() >= expSize { err = v.StepInto() } - if v.state == NoneState { - v.state = BreakState + if v.state == vmstate.None { + v.state = vmstate.Break } return err } @@ -523,20 +524,20 @@ func (v *VM) StepOver() error { return err } - if v.state == BreakState { - v.state = NoneState + if v.state == vmstate.Break { + v.state = vmstate.None } expSize := v.istack.Len() for { err = v.StepInto() - if !(v.state == NoneState && v.istack.Len() > expSize) { + if !(v.state == vmstate.None && v.istack.Len() > expSize) { break } } - if v.state == NoneState { - v.state = BreakState + if v.state == vmstate.None { + v.state = vmstate.Break } return err @@ -545,22 +546,22 @@ func (v *VM) StepOver() error { // HasFailed returns whether the VM is in the failed state now. Usually, it's used to // check status after Run. func (v *VM) HasFailed() bool { - return v.state.HasFlag(FaultState) + return v.state.HasFlag(vmstate.Fault) } // HasStopped returns whether the VM is in the Halt or Failed state. func (v *VM) HasStopped() bool { - return v.state.HasFlag(HaltState) || v.state.HasFlag(FaultState) + return v.state.HasFlag(vmstate.Halt) || v.state.HasFlag(vmstate.Fault) } // HasHalted returns whether the VM is in the Halt state. func (v *VM) HasHalted() bool { - return v.state.HasFlag(HaltState) + return v.state.HasFlag(vmstate.Halt) } // AtBreakpoint returns whether the VM is at breakpoint. func (v *VM) AtBreakpoint() bool { - return v.state.HasFlag(BreakState) + return v.state.HasFlag(vmstate.Break) } // GetInteropID converts instruction parameter to an interop ID. @@ -574,10 +575,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro // each panic at a central point, putting the VM in a fault state and setting error. defer func() { if errRecover := recover(); errRecover != nil { - v.state = FaultState + v.state = vmstate.Fault err = newError(ctx.ip, op, errRecover) } else if v.refs > MaxStackSize { - v.state = FaultState + v.state = vmstate.Fault err = newError(ctx.ip, op, "stack is too big") } }() @@ -1469,7 +1470,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.unloadContext(oldCtx) if v.istack.Len() == 0 { - v.state = HaltState + v.state = vmstate.Halt break } diff --git a/pkg/vm/state.go b/pkg/vm/vmstate/state.go similarity index 51% rename from pkg/vm/state.go rename to pkg/vm/vmstate/state.go index 64984b62b..75cf367af 100644 --- a/pkg/vm/state.go +++ b/pkg/vm/vmstate/state.go @@ -1,23 +1,29 @@ -package vm +/* +Package vmstate contains a set of VM state flags along with appropriate type. +It provides a set of conversion/marshaling functions/methods for this type as +well. This package is made to make VM state reusable across all of the other +components that need it without importing whole VM package. +*/ +package vmstate import ( "errors" "strings" ) -// State of the VM. +// State of the VM. It's a set of flags stored in the integer number. type State uint8 // Available States. const ( - // HaltState represents HALT VM state. - HaltState State = 1 << iota - // FaultState represents FAULT VM state. - FaultState - // BreakState represents BREAK VM state. - BreakState - // NoneState represents NONE VM state. - NoneState State = 0 + // Halt represents HALT VM state (finished normally). + Halt State = 1 << iota + // Fault represents FAULT VM state (finished with an error). + Fault + // Break represents BREAK VM state (running, debug mode). + Break + // None represents NONE VM state (not started yet). + None State = 0 ) // HasFlag checks for State flag presence. @@ -25,40 +31,40 @@ func (s State) HasFlag(f State) bool { return s&f != 0 } -// String implements the stringer interface. +// String implements the fmt.Stringer interface. func (s State) String() string { - if s == NoneState { + if s == None { return "NONE" } ss := make([]string, 0, 3) - if s.HasFlag(HaltState) { + if s.HasFlag(Halt) { ss = append(ss, "HALT") } - if s.HasFlag(FaultState) { + if s.HasFlag(Fault) { ss = append(ss, "FAULT") } - if s.HasFlag(BreakState) { + if s.HasFlag(Break) { ss = append(ss, "BREAK") } return strings.Join(ss, ", ") } -// StateFromString converts a string into the VM State. -func StateFromString(s string) (st State, err error) { +// FromString converts a string into the State. +func FromString(s string) (st State, err error) { if s = strings.TrimSpace(s); s == "NONE" { - return NoneState, nil + return None, nil } ss := strings.Split(s, ",") for _, state := range ss { switch state = strings.TrimSpace(state); state { case "HALT": - st |= HaltState + st |= Halt case "FAULT": - st |= FaultState + st |= Fault case "BREAK": - st |= BreakState + st |= Break default: return 0, errors.New("unknown state") } @@ -78,6 +84,6 @@ func (s *State) UnmarshalJSON(data []byte) (err error) { return errors.New("wrong format") } - *s, err = StateFromString(string(data[1 : l-1])) + *s, err = FromString(string(data[1 : l-1])) return } diff --git a/pkg/vm/vmstate/state_test.go b/pkg/vm/vmstate/state_test.go new file mode 100644 index 000000000..e7dcccce8 --- /dev/null +++ b/pkg/vm/vmstate/state_test.go @@ -0,0 +1,97 @@ +package vmstate + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFromString(t *testing.T) { + var ( + s State + err error + ) + + s, err = FromString("HALT") + assert.NoError(t, err) + assert.Equal(t, Halt, s) + + s, err = FromString("BREAK") + assert.NoError(t, err) + assert.Equal(t, Break, s) + + s, err = FromString("FAULT") + assert.NoError(t, err) + assert.Equal(t, Fault, s) + + s, err = FromString("NONE") + assert.NoError(t, err) + assert.Equal(t, None, s) + + s, err = FromString("HALT, BREAK") + assert.NoError(t, err) + assert.Equal(t, Halt|Break, s) + + s, err = FromString("FAULT, BREAK") + assert.NoError(t, err) + assert.Equal(t, Fault|Break, s) + + _, err = FromString("HALT, KEK") + assert.Error(t, err) +} + +func TestState_HasFlag(t *testing.T) { + assert.True(t, Halt.HasFlag(Halt)) + assert.True(t, Break.HasFlag(Break)) + assert.True(t, Fault.HasFlag(Fault)) + assert.True(t, (Halt | Break).HasFlag(Halt)) + assert.True(t, (Halt | Break).HasFlag(Break)) + + assert.False(t, Halt.HasFlag(Break)) + assert.False(t, None.HasFlag(Halt)) + assert.False(t, (Fault | Break).HasFlag(Halt)) +} + +func TestState_MarshalJSON(t *testing.T) { + var ( + data []byte + err error + ) + + data, err = json.Marshal(Halt | Break) + assert.NoError(t, err) + assert.Equal(t, data, []byte(`"HALT, BREAK"`)) + + data, err = json.Marshal(Fault) + assert.NoError(t, err) + assert.Equal(t, data, []byte(`"FAULT"`)) +} + +func TestState_UnmarshalJSON(t *testing.T) { + var ( + s State + err error + ) + + err = json.Unmarshal([]byte(`"HALT, BREAK"`), &s) + assert.NoError(t, err) + assert.Equal(t, Halt|Break, s) + + err = json.Unmarshal([]byte(`"FAULT, BREAK"`), &s) + assert.NoError(t, err) + assert.Equal(t, Fault|Break, s) + + err = json.Unmarshal([]byte(`"NONE"`), &s) + assert.NoError(t, err) + assert.Equal(t, None, s) +} + +// TestState_EnumCompat tests that byte value of State matches the C#'s one got from +// https://github.com/neo-project/neo-vm/blob/0028d862e253bda3c12eb8bb007a2d95822d3922/src/neo-vm/VMState.cs#L16. +func TestState_EnumCompat(t *testing.T) { + assert.Equal(t, byte(0), byte(None)) + assert.Equal(t, byte(1<<0), byte(Halt)) + assert.Equal(t, byte(1<<1), byte(Fault)) + assert.Equal(t, byte(1<<2), byte(Break)) +}