mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-22 09:43:47 +00:00
make State a set as in reference C# implementation (#123)
* make State a set as in reference C# implementation * fix issues
This commit is contained in:
parent
001a0e601e
commit
57cb289bcd
4 changed files with 161 additions and 20 deletions
|
@ -1,6 +1,9 @@
|
|||
package rpc
|
||||
|
||||
import "github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
import (
|
||||
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||
"github.com/CityOfZion/neo-go/pkg/vm"
|
||||
)
|
||||
|
||||
type InvokeScriptResponse struct {
|
||||
responseHeader
|
||||
|
@ -11,9 +14,9 @@ type InvokeScriptResponse struct {
|
|||
// InvokeResult represents the outcome of a script that is
|
||||
// executed by the NEO VM.
|
||||
type InvokeResult struct {
|
||||
State string `json:"state"`
|
||||
GasConsumed string `json:"gas_consumed"`
|
||||
Script string `json:"script"`
|
||||
State vm.State `json:"state"`
|
||||
GasConsumed string `json:"gas_consumed"`
|
||||
Script string `json:"script"`
|
||||
Stack []StackParam
|
||||
}
|
||||
|
||||
|
|
|
@ -1,25 +1,75 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// State of the VM.
|
||||
type State uint
|
||||
type State uint8
|
||||
|
||||
// Available States.
|
||||
const (
|
||||
noneState State = iota
|
||||
haltState
|
||||
noneState State = 0
|
||||
haltState State = 1 << iota
|
||||
faultState
|
||||
breakState
|
||||
)
|
||||
|
||||
func (s State) HasFlag(f State) bool {
|
||||
return s&f != 0
|
||||
}
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case haltState:
|
||||
return "HALT"
|
||||
case faultState:
|
||||
return "FAULT"
|
||||
case breakState:
|
||||
return "BREAK"
|
||||
default:
|
||||
if s == noneState {
|
||||
return "NONE"
|
||||
}
|
||||
|
||||
ss := make([]string, 0, 3)
|
||||
if s.HasFlag(haltState) {
|
||||
ss = append(ss, "HALT")
|
||||
}
|
||||
if s.HasFlag(faultState) {
|
||||
ss = append(ss, "FAULT")
|
||||
}
|
||||
if s.HasFlag(breakState) {
|
||||
ss = append(ss, "BREAK")
|
||||
}
|
||||
return strings.Join(ss, ", ")
|
||||
}
|
||||
|
||||
func StateFromString(s string) (st State, err error) {
|
||||
if s = strings.TrimSpace(s); s == "NONE" {
|
||||
return noneState, nil
|
||||
}
|
||||
|
||||
ss := strings.Split(s, ",")
|
||||
for _, state := range ss {
|
||||
switch state = strings.TrimSpace(state); state {
|
||||
case "HALT":
|
||||
st |= haltState
|
||||
case "FAULT":
|
||||
st |= faultState
|
||||
case "BREAK":
|
||||
st |= breakState
|
||||
default:
|
||||
return 0, errors.New("unknown state")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s State) MarshalJSON() (data []byte, err error) {
|
||||
return []byte(`"` + s.String() + `"`), nil
|
||||
}
|
||||
|
||||
func (s *State) UnmarshalJSON(data []byte) (err error) {
|
||||
l := len(data)
|
||||
if l < 2 || data[0] != '"' || data[l-1] != '"' {
|
||||
return errors.New("wrong format")
|
||||
}
|
||||
|
||||
*s, err = StateFromString(string(data[1 : l-1]))
|
||||
return
|
||||
}
|
||||
|
|
88
pkg/vm/state_test.go
Normal file
88
pkg/vm/state_test.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStateFromString(t *testing.T) {
|
||||
var (
|
||||
s State
|
||||
err error
|
||||
)
|
||||
|
||||
s, err = StateFromString("HALT")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, haltState, s)
|
||||
|
||||
s, err = StateFromString("BREAK")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, breakState, s)
|
||||
|
||||
s, err = StateFromString("FAULT")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, faultState, s)
|
||||
|
||||
s, err = StateFromString("NONE")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, noneState, s)
|
||||
|
||||
s, err = StateFromString("HALT, BREAK")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, haltState|breakState, s)
|
||||
|
||||
s, err = StateFromString("FAULT, BREAK")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, faultState|breakState, s)
|
||||
|
||||
s, err = StateFromString("HALT, KEK")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestState_HasFlag(t *testing.T) {
|
||||
assert.True(t, haltState.HasFlag(haltState))
|
||||
assert.True(t, breakState.HasFlag(breakState))
|
||||
assert.True(t, faultState.HasFlag(faultState))
|
||||
assert.True(t, (haltState | breakState).HasFlag(haltState))
|
||||
assert.True(t, (haltState | breakState).HasFlag(breakState))
|
||||
|
||||
assert.False(t, haltState.HasFlag(breakState))
|
||||
assert.False(t, noneState.HasFlag(haltState))
|
||||
assert.False(t, (faultState | breakState).HasFlag(haltState))
|
||||
}
|
||||
|
||||
func TestState_MarshalJSON(t *testing.T) {
|
||||
var (
|
||||
data []byte
|
||||
err error
|
||||
)
|
||||
|
||||
data, err = json.Marshal(haltState | breakState)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, []byte(`"HALT, BREAK"`))
|
||||
|
||||
data, err = json.Marshal(faultState)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, []byte(`"FAULT"`))
|
||||
}
|
||||
|
||||
func TestState_UnmarshalJSON(t *testing.T) {
|
||||
var (
|
||||
s State
|
||||
err error
|
||||
)
|
||||
|
||||
err = json.Unmarshal([]byte(`"HALT, BREAK"`), &s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, haltState|breakState, s)
|
||||
|
||||
err = json.Unmarshal([]byte(`"FAULT, BREAK"`), &s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, faultState|breakState, s)
|
||||
|
||||
err = json.Unmarshal([]byte(`"NONE"`), &s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, noneState, s)
|
||||
}
|
10
pkg/vm/vm.go
10
pkg/vm/vm.go
|
@ -194,21 +194,21 @@ func (v *VM) Run() {
|
|||
|
||||
v.state = noneState
|
||||
for {
|
||||
switch v.state {
|
||||
case haltState:
|
||||
switch {
|
||||
case v.state.HasFlag(haltState):
|
||||
if !v.mute {
|
||||
fmt.Println(v.Stack("estack"))
|
||||
}
|
||||
return
|
||||
case breakState:
|
||||
case v.state.HasFlag(breakState):
|
||||
ctx := v.Context()
|
||||
i, op := ctx.CurrInstr()
|
||||
fmt.Printf("at breakpoint %d (%s)\n", i, op.String())
|
||||
return
|
||||
case faultState:
|
||||
case v.state.HasFlag(faultState):
|
||||
fmt.Println("FAULT")
|
||||
return
|
||||
case noneState:
|
||||
case v.state == noneState:
|
||||
v.Step()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue