callflag: add C#-compliant JSONization, fix #2040

This commit is contained in:
Roman Khimov 2021-07-05 18:23:16 +03:00
parent e5d26a5df1
commit ac126a300b
3 changed files with 161 additions and 1 deletions

View file

@ -1,5 +1,11 @@
package callflag package callflag
import (
"encoding/json"
"errors"
"strings"
)
// CallFlag represents call flag. // CallFlag represents call flag.
type CallFlag byte type CallFlag byte
@ -16,7 +22,98 @@ const (
NoneFlag CallFlag = 0 NoneFlag CallFlag = 0
) )
var flagString = map[CallFlag]string{
ReadStates: "ReadStates",
WriteStates: "WriteStates",
AllowCall: "AllowCall",
AllowNotify: "AllowNotify",
States: "States",
ReadOnly: "ReadOnly",
All: "All",
NoneFlag: "None",
}
// basicFlags are all flags except All and None. It's used to stringify CallFlag
// where its bits are matched against these values from values with sets of bits
// to simple flags which is important to produce proper string representation
// matching C# Enum handling.
var basicFlags = []CallFlag{ReadOnly, States, ReadStates, WriteStates, AllowCall, AllowNotify}
// FromString parses input string and returns corresponding CallFlag.
func FromString(s string) (CallFlag, error) {
flags := strings.Split(s, ",")
if len(flags) == 0 {
return NoneFlag, errors.New("empty flags")
}
if len(flags) == 1 {
for f, str := range flagString {
if s == str {
return f, nil
}
}
return NoneFlag, errors.New("unknown flag")
}
var res CallFlag
for _, flag := range flags {
var knownFlag bool
flag = strings.TrimSpace(flag)
for _, f := range basicFlags {
if flag == flagString[f] {
res |= f
knownFlag = true
break
}
}
if !knownFlag {
return NoneFlag, errors.New("unknown/inappropriate flag")
}
}
return res, nil
}
// Has returns true iff all bits set in cf are also set in f. // Has returns true iff all bits set in cf are also set in f.
func (f CallFlag) Has(cf CallFlag) bool { func (f CallFlag) Has(cf CallFlag) bool {
return f&cf == cf return f&cf == cf
} }
// String implements Stringer interface.
func (f CallFlag) String() string {
if flagString[f] != "" {
return flagString[f]
}
var res string
for _, flag := range basicFlags {
if f.Has(flag) {
if len(res) != 0 {
res += ", "
}
res += flagString[flag]
f &= ^flag // Some "States" shouldn't be combined with "ReadStates".
}
}
return res
}
// MarshalJSON implements json.Marshaler interface.
func (f CallFlag) MarshalJSON() ([]byte, error) {
return []byte(`"` + f.String() + `"`), nil
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (f *CallFlag) UnmarshalJSON(data []byte) error {
var js string
if err := json.Unmarshal(data, &js); err != nil {
return err
}
flag, err := FromString(js)
if err != nil {
return err
}
*f = flag
return nil
}

View file

@ -3,6 +3,7 @@ package callflag
import ( import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -12,3 +13,65 @@ func TestCallFlag_Has(t *testing.T) {
require.False(t, (AllowCall).Has(AllowCall|AllowNotify)) require.False(t, (AllowCall).Has(AllowCall|AllowNotify))
require.True(t, All.Has(ReadOnly)) require.True(t, All.Has(ReadOnly))
} }
func TestCallFlagString(t *testing.T) {
var cases = map[CallFlag]string{
NoneFlag: "None",
All: "All",
ReadStates: "ReadStates",
States: "States",
ReadOnly: "ReadOnly",
States | AllowCall: "ReadOnly, WriteStates",
ReadOnly | AllowNotify: "ReadOnly, AllowNotify",
States | AllowNotify: "States, AllowNotify",
}
for f, s := range cases {
require.Equal(t, s, f.String())
}
}
func TestFromString(t *testing.T) {
var cases = map[string]struct {
flag CallFlag
err bool
}{
"None": {NoneFlag, false},
"All": {All, false},
"ReadStates": {ReadStates, false},
"States": {States, false},
"ReadOnly": {ReadOnly, false},
"ReadOnly, WriteStates": {States | AllowCall, false},
"States, AllowCall": {States | AllowCall, false},
"AllowCall, States": {States | AllowCall, false},
"States, ReadOnly": {States | AllowCall, false},
" AllowCall,AllowNotify": {AllowNotify | AllowCall, false},
"BlahBlah": {NoneFlag, true},
"States, All": {NoneFlag, true},
"ReadStates,,AllowCall": {NoneFlag, true},
"ReadStates;AllowCall": {NoneFlag, true},
"readstates": {NoneFlag, true},
" All": {NoneFlag, true},
"None, All": {NoneFlag, true},
}
for s, res := range cases {
f, err := FromString(s)
require.True(t, res.err == (err != nil), "Input: '"+s+"'")
require.Equal(t, res.flag, f)
}
}
func TestMarshalUnmarshalJSON(t *testing.T) {
var f = States
testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag))
f = States | AllowNotify
testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag))
forig := f
err := f.UnmarshalJSON([]byte("42"))
require.Error(t, err)
require.Equal(t, forig, f)
err = f.UnmarshalJSON([]byte(`"State"`))
require.Error(t, err)
require.Equal(t, forig, f)
}

View file

@ -147,7 +147,7 @@ func TestMarshalUnmarshalJSON(t *testing.T) {
"method": "someMethod", "method": "someMethod",
"paramcount": 3, "paramcount": 3,
"hasreturnvalue": true, "hasreturnvalue": true,
"callflags": `+strconv.FormatInt(int64(expected.Tokens[0].CallFlag), 10)+` "callflags": "`+expected.Tokens[0].CallFlag.String()+`"
} }
], ],
"script": "`+base64.StdEncoding.EncodeToString(expected.Script)+`", "script": "`+base64.StdEncoding.EncodeToString(expected.Script)+`",