diff --git a/pkg/smartcontract/callflag/call_flags.go b/pkg/smartcontract/callflag/call_flags.go index 07bbb15b4..c31b819b4 100644 --- a/pkg/smartcontract/callflag/call_flags.go +++ b/pkg/smartcontract/callflag/call_flags.go @@ -1,5 +1,11 @@ package callflag +import ( + "encoding/json" + "errors" + "strings" +) + // CallFlag represents call flag. type CallFlag byte @@ -16,7 +22,98 @@ const ( NoneFlag CallFlag = 0 ) +var flagString = map[CallFlag]string{ + ReadStates: "ReadStates", + WriteStates: "WriteStates", + AllowCall: "AllowCall", + AllowNotify: "AllowNotify", + States: "States", + ReadOnly: "ReadOnly", + All: "All", + NoneFlag: "None", +} + +// basicFlags are all flags except All and None. It's used to stringify CallFlag +// where its bits are matched against these values from values with sets of bits +// to simple flags which is important to produce proper string representation +// matching C# Enum handling. +var basicFlags = []CallFlag{ReadOnly, States, ReadStates, WriteStates, AllowCall, AllowNotify} + +// FromString parses input string and returns corresponding CallFlag. +func FromString(s string) (CallFlag, error) { + flags := strings.Split(s, ",") + if len(flags) == 0 { + return NoneFlag, errors.New("empty flags") + } + if len(flags) == 1 { + for f, str := range flagString { + if s == str { + return f, nil + } + } + return NoneFlag, errors.New("unknown flag") + } + + var res CallFlag + + for _, flag := range flags { + var knownFlag bool + + flag = strings.TrimSpace(flag) + for _, f := range basicFlags { + if flag == flagString[f] { + res |= f + knownFlag = true + break + } + } + if !knownFlag { + return NoneFlag, errors.New("unknown/inappropriate flag") + } + } + return res, nil +} + // Has returns true iff all bits set in cf are also set in f. func (f CallFlag) Has(cf CallFlag) bool { return f&cf == cf } + +// String implements Stringer interface. +func (f CallFlag) String() string { + if flagString[f] != "" { + return flagString[f] + } + + var res string + + for _, flag := range basicFlags { + if f.Has(flag) { + if len(res) != 0 { + res += ", " + } + res += flagString[flag] + f &= ^flag // Some "States" shouldn't be combined with "ReadStates". + } + } + return res +} + +// MarshalJSON implements json.Marshaler interface. +func (f CallFlag) MarshalJSON() ([]byte, error) { + return []byte(`"` + f.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (f *CallFlag) UnmarshalJSON(data []byte) error { + var js string + if err := json.Unmarshal(data, &js); err != nil { + return err + } + flag, err := FromString(js) + if err != nil { + return err + } + *f = flag + return nil +} diff --git a/pkg/smartcontract/callflag/call_flags_test.go b/pkg/smartcontract/callflag/call_flags_test.go index 7dd9048b8..b75145150 100644 --- a/pkg/smartcontract/callflag/call_flags_test.go +++ b/pkg/smartcontract/callflag/call_flags_test.go @@ -3,6 +3,7 @@ package callflag import ( "testing" + "github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/stretchr/testify/require" ) @@ -12,3 +13,65 @@ func TestCallFlag_Has(t *testing.T) { require.False(t, (AllowCall).Has(AllowCall|AllowNotify)) require.True(t, All.Has(ReadOnly)) } + +func TestCallFlagString(t *testing.T) { + var cases = map[CallFlag]string{ + NoneFlag: "None", + All: "All", + ReadStates: "ReadStates", + States: "States", + ReadOnly: "ReadOnly", + States | AllowCall: "ReadOnly, WriteStates", + ReadOnly | AllowNotify: "ReadOnly, AllowNotify", + States | AllowNotify: "States, AllowNotify", + } + for f, s := range cases { + require.Equal(t, s, f.String()) + } +} + +func TestFromString(t *testing.T) { + var cases = map[string]struct { + flag CallFlag + err bool + }{ + "None": {NoneFlag, false}, + "All": {All, false}, + "ReadStates": {ReadStates, false}, + "States": {States, false}, + "ReadOnly": {ReadOnly, false}, + "ReadOnly, WriteStates": {States | AllowCall, false}, + "States, AllowCall": {States | AllowCall, false}, + "AllowCall, States": {States | AllowCall, false}, + "States, ReadOnly": {States | AllowCall, false}, + " AllowCall,AllowNotify": {AllowNotify | AllowCall, false}, + "BlahBlah": {NoneFlag, true}, + "States, All": {NoneFlag, true}, + "ReadStates,,AllowCall": {NoneFlag, true}, + "ReadStates;AllowCall": {NoneFlag, true}, + "readstates": {NoneFlag, true}, + " All": {NoneFlag, true}, + "None, All": {NoneFlag, true}, + } + for s, res := range cases { + f, err := FromString(s) + require.True(t, res.err == (err != nil), "Input: '"+s+"'") + require.Equal(t, res.flag, f) + } +} + +func TestMarshalUnmarshalJSON(t *testing.T) { + var f = States + testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag)) + f = States | AllowNotify + testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag)) + + forig := f + err := f.UnmarshalJSON([]byte("42")) + require.Error(t, err) + require.Equal(t, forig, f) + + err = f.UnmarshalJSON([]byte(`"State"`)) + require.Error(t, err) + require.Equal(t, forig, f) +} diff --git a/pkg/smartcontract/nef/nef_test.go b/pkg/smartcontract/nef/nef_test.go index 8c52bae2c..03d3af86c 100644 --- a/pkg/smartcontract/nef/nef_test.go +++ b/pkg/smartcontract/nef/nef_test.go @@ -147,7 +147,7 @@ func TestMarshalUnmarshalJSON(t *testing.T) { "method": "someMethod", "paramcount": 3, "hasreturnvalue": true, - "callflags": `+strconv.FormatInt(int64(expected.Tokens[0].CallFlag), 10)+` + "callflags": "`+expected.Tokens[0].CallFlag.String()+`" } ], "script": "`+base64.StdEncoding.EncodeToString(expected.Script)+`",