diff --git a/pkg/rpc/types.go b/pkg/rpc/types.go index 25b0e50af..3a51fd9ed 100644 --- a/pkg/rpc/types.go +++ b/pkg/rpc/types.go @@ -1,6 +1,9 @@ package rpc -import "github.com/CityOfZion/neo-go/pkg/core/transaction" +import ( + "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/vm" +) type InvokeScriptResponse struct { responseHeader @@ -11,9 +14,9 @@ type InvokeScriptResponse struct { // InvokeResult represents the outcome of a script that is // executed by the NEO VM. type InvokeResult struct { - State string `json:"state"` - GasConsumed string `json:"gas_consumed"` - Script string `json:"script"` + State vm.State `json:"state"` + GasConsumed string `json:"gas_consumed"` + Script string `json:"script"` Stack []StackParam } diff --git a/pkg/vm/state.go b/pkg/vm/state.go index 3ec884b68..62d4367b2 100644 --- a/pkg/vm/state.go +++ b/pkg/vm/state.go @@ -1,25 +1,75 @@ package vm +import ( + "strings" + + "github.com/pkg/errors" +) + // State of the VM. -type State uint +type State uint8 // Available States. const ( - noneState State = iota - haltState + noneState State = 0 + haltState State = 1 << iota faultState breakState ) +func (s State) HasFlag(f State) bool { + return s&f != 0 +} + func (s State) String() string { - switch s { - case haltState: - return "HALT" - case faultState: - return "FAULT" - case breakState: - return "BREAK" - default: + if s == noneState { return "NONE" } + + ss := make([]string, 0, 3) + if s.HasFlag(haltState) { + ss = append(ss, "HALT") + } + if s.HasFlag(faultState) { + ss = append(ss, "FAULT") + } + if s.HasFlag(breakState) { + ss = append(ss, "BREAK") + } + return strings.Join(ss, ", ") +} + +func StateFromString(s string) (st State, err error) { + if s = strings.TrimSpace(s); s == "NONE" { + return noneState, nil + } + + ss := strings.Split(s, ",") + for _, state := range ss { + switch state = strings.TrimSpace(state); state { + case "HALT": + st |= haltState + case "FAULT": + st |= faultState + case "BREAK": + st |= breakState + default: + return 0, errors.New("unknown state") + } + } + return +} + +func (s State) MarshalJSON() (data []byte, err error) { + return []byte(`"` + s.String() + `"`), nil +} + +func (s *State) UnmarshalJSON(data []byte) (err error) { + l := len(data) + if l < 2 || data[0] != '"' || data[l-1] != '"' { + return errors.New("wrong format") + } + + *s, err = StateFromString(string(data[1 : l-1])) + return } diff --git a/pkg/vm/state_test.go b/pkg/vm/state_test.go new file mode 100644 index 000000000..f1d2154f1 --- /dev/null +++ b/pkg/vm/state_test.go @@ -0,0 +1,88 @@ +package vm + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStateFromString(t *testing.T) { + var ( + s State + err error + ) + + s, err = StateFromString("HALT") + assert.NoError(t, err) + assert.Equal(t, haltState, s) + + s, err = StateFromString("BREAK") + assert.NoError(t, err) + assert.Equal(t, breakState, s) + + s, err = StateFromString("FAULT") + assert.NoError(t, err) + assert.Equal(t, faultState, s) + + s, err = StateFromString("NONE") + assert.NoError(t, err) + assert.Equal(t, noneState, s) + + s, err = StateFromString("HALT, BREAK") + assert.NoError(t, err) + assert.Equal(t, haltState|breakState, s) + + s, err = StateFromString("FAULT, BREAK") + assert.NoError(t, err) + assert.Equal(t, faultState|breakState, s) + + s, err = StateFromString("HALT, KEK") + assert.Error(t, err) +} + +func TestState_HasFlag(t *testing.T) { + assert.True(t, haltState.HasFlag(haltState)) + assert.True(t, breakState.HasFlag(breakState)) + assert.True(t, faultState.HasFlag(faultState)) + assert.True(t, (haltState | breakState).HasFlag(haltState)) + assert.True(t, (haltState | breakState).HasFlag(breakState)) + + assert.False(t, haltState.HasFlag(breakState)) + assert.False(t, noneState.HasFlag(haltState)) + assert.False(t, (faultState | breakState).HasFlag(haltState)) +} + +func TestState_MarshalJSON(t *testing.T) { + var ( + data []byte + err error + ) + + data, err = json.Marshal(haltState | breakState) + assert.NoError(t, err) + assert.Equal(t, data, []byte(`"HALT, BREAK"`)) + + data, err = json.Marshal(faultState) + assert.NoError(t, err) + assert.Equal(t, data, []byte(`"FAULT"`)) +} + +func TestState_UnmarshalJSON(t *testing.T) { + var ( + s State + err error + ) + + err = json.Unmarshal([]byte(`"HALT, BREAK"`), &s) + assert.NoError(t, err) + assert.Equal(t, haltState|breakState, s) + + err = json.Unmarshal([]byte(`"FAULT, BREAK"`), &s) + assert.NoError(t, err) + assert.Equal(t, faultState|breakState, s) + + err = json.Unmarshal([]byte(`"NONE"`), &s) + assert.NoError(t, err) + assert.Equal(t, noneState, s) +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 9876ab5f0..f201ae569 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -194,21 +194,21 @@ func (v *VM) Run() { v.state = noneState for { - switch v.state { - case haltState: + switch { + case v.state.HasFlag(haltState): if !v.mute { fmt.Println(v.Stack("estack")) } return - case breakState: + case v.state.HasFlag(breakState): ctx := v.Context() i, op := ctx.CurrInstr() fmt.Printf("at breakpoint %d (%s)\n", i, op.String()) return - case faultState: + case v.state.HasFlag(faultState): fmt.Println("FAULT") return - case noneState: + case v.state == noneState: v.Step() } }