callflag: add C#-compliant JSONization, fix #2040
This commit is contained in:
parent
e5d26a5df1
commit
ac126a300b
3 changed files with 161 additions and 1 deletions
|
@ -1,5 +1,11 @@
|
||||||
package callflag
|
package callflag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// CallFlag represents call flag.
|
// CallFlag represents call flag.
|
||||||
type CallFlag byte
|
type CallFlag byte
|
||||||
|
|
||||||
|
@ -16,7 +22,98 @@ const (
|
||||||
NoneFlag CallFlag = 0
|
NoneFlag CallFlag = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flagString = map[CallFlag]string{
|
||||||
|
ReadStates: "ReadStates",
|
||||||
|
WriteStates: "WriteStates",
|
||||||
|
AllowCall: "AllowCall",
|
||||||
|
AllowNotify: "AllowNotify",
|
||||||
|
States: "States",
|
||||||
|
ReadOnly: "ReadOnly",
|
||||||
|
All: "All",
|
||||||
|
NoneFlag: "None",
|
||||||
|
}
|
||||||
|
|
||||||
|
// basicFlags are all flags except All and None. It's used to stringify CallFlag
|
||||||
|
// where its bits are matched against these values from values with sets of bits
|
||||||
|
// to simple flags which is important to produce proper string representation
|
||||||
|
// matching C# Enum handling.
|
||||||
|
var basicFlags = []CallFlag{ReadOnly, States, ReadStates, WriteStates, AllowCall, AllowNotify}
|
||||||
|
|
||||||
|
// FromString parses input string and returns corresponding CallFlag.
|
||||||
|
func FromString(s string) (CallFlag, error) {
|
||||||
|
flags := strings.Split(s, ",")
|
||||||
|
if len(flags) == 0 {
|
||||||
|
return NoneFlag, errors.New("empty flags")
|
||||||
|
}
|
||||||
|
if len(flags) == 1 {
|
||||||
|
for f, str := range flagString {
|
||||||
|
if s == str {
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NoneFlag, errors.New("unknown flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
var res CallFlag
|
||||||
|
|
||||||
|
for _, flag := range flags {
|
||||||
|
var knownFlag bool
|
||||||
|
|
||||||
|
flag = strings.TrimSpace(flag)
|
||||||
|
for _, f := range basicFlags {
|
||||||
|
if flag == flagString[f] {
|
||||||
|
res |= f
|
||||||
|
knownFlag = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !knownFlag {
|
||||||
|
return NoneFlag, errors.New("unknown/inappropriate flag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Has returns true iff all bits set in cf are also set in f.
|
// Has returns true iff all bits set in cf are also set in f.
|
||||||
func (f CallFlag) Has(cf CallFlag) bool {
|
func (f CallFlag) Has(cf CallFlag) bool {
|
||||||
return f&cf == cf
|
return f&cf == cf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String implements Stringer interface.
|
||||||
|
func (f CallFlag) String() string {
|
||||||
|
if flagString[f] != "" {
|
||||||
|
return flagString[f]
|
||||||
|
}
|
||||||
|
|
||||||
|
var res string
|
||||||
|
|
||||||
|
for _, flag := range basicFlags {
|
||||||
|
if f.Has(flag) {
|
||||||
|
if len(res) != 0 {
|
||||||
|
res += ", "
|
||||||
|
}
|
||||||
|
res += flagString[flag]
|
||||||
|
f &= ^flag // Some "States" shouldn't be combined with "ReadStates".
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler interface.
|
||||||
|
func (f CallFlag) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(`"` + f.String() + `"`), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler interface.
|
||||||
|
func (f *CallFlag) UnmarshalJSON(data []byte) error {
|
||||||
|
var js string
|
||||||
|
if err := json.Unmarshal(data, &js); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
flag, err := FromString(js)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*f = flag
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package callflag
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,3 +13,65 @@ func TestCallFlag_Has(t *testing.T) {
|
||||||
require.False(t, (AllowCall).Has(AllowCall|AllowNotify))
|
require.False(t, (AllowCall).Has(AllowCall|AllowNotify))
|
||||||
require.True(t, All.Has(ReadOnly))
|
require.True(t, All.Has(ReadOnly))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCallFlagString(t *testing.T) {
|
||||||
|
var cases = map[CallFlag]string{
|
||||||
|
NoneFlag: "None",
|
||||||
|
All: "All",
|
||||||
|
ReadStates: "ReadStates",
|
||||||
|
States: "States",
|
||||||
|
ReadOnly: "ReadOnly",
|
||||||
|
States | AllowCall: "ReadOnly, WriteStates",
|
||||||
|
ReadOnly | AllowNotify: "ReadOnly, AllowNotify",
|
||||||
|
States | AllowNotify: "States, AllowNotify",
|
||||||
|
}
|
||||||
|
for f, s := range cases {
|
||||||
|
require.Equal(t, s, f.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromString(t *testing.T) {
|
||||||
|
var cases = map[string]struct {
|
||||||
|
flag CallFlag
|
||||||
|
err bool
|
||||||
|
}{
|
||||||
|
"None": {NoneFlag, false},
|
||||||
|
"All": {All, false},
|
||||||
|
"ReadStates": {ReadStates, false},
|
||||||
|
"States": {States, false},
|
||||||
|
"ReadOnly": {ReadOnly, false},
|
||||||
|
"ReadOnly, WriteStates": {States | AllowCall, false},
|
||||||
|
"States, AllowCall": {States | AllowCall, false},
|
||||||
|
"AllowCall, States": {States | AllowCall, false},
|
||||||
|
"States, ReadOnly": {States | AllowCall, false},
|
||||||
|
" AllowCall,AllowNotify": {AllowNotify | AllowCall, false},
|
||||||
|
"BlahBlah": {NoneFlag, true},
|
||||||
|
"States, All": {NoneFlag, true},
|
||||||
|
"ReadStates,,AllowCall": {NoneFlag, true},
|
||||||
|
"ReadStates;AllowCall": {NoneFlag, true},
|
||||||
|
"readstates": {NoneFlag, true},
|
||||||
|
" All": {NoneFlag, true},
|
||||||
|
"None, All": {NoneFlag, true},
|
||||||
|
}
|
||||||
|
for s, res := range cases {
|
||||||
|
f, err := FromString(s)
|
||||||
|
require.True(t, res.err == (err != nil), "Input: '"+s+"'")
|
||||||
|
require.Equal(t, res.flag, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalJSON(t *testing.T) {
|
||||||
|
var f = States
|
||||||
|
testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag))
|
||||||
|
f = States | AllowNotify
|
||||||
|
testserdes.MarshalUnmarshalJSON(t, &f, new(CallFlag))
|
||||||
|
|
||||||
|
forig := f
|
||||||
|
err := f.UnmarshalJSON([]byte("42"))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, forig, f)
|
||||||
|
|
||||||
|
err = f.UnmarshalJSON([]byte(`"State"`))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, forig, f)
|
||||||
|
}
|
||||||
|
|
|
@ -147,7 +147,7 @@ func TestMarshalUnmarshalJSON(t *testing.T) {
|
||||||
"method": "someMethod",
|
"method": "someMethod",
|
||||||
"paramcount": 3,
|
"paramcount": 3,
|
||||||
"hasreturnvalue": true,
|
"hasreturnvalue": true,
|
||||||
"callflags": `+strconv.FormatInt(int64(expected.Tokens[0].CallFlag), 10)+`
|
"callflags": "`+expected.Tokens[0].CallFlag.String()+`"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"script": "`+base64.StdEncoding.EncodeToString(expected.Script)+`",
|
"script": "`+base64.StdEncoding.EncodeToString(expected.Script)+`",
|
||||||
|
|
Loading…
Reference in a new issue