diff --git a/pkg/core/interop/runtime/witness.go b/pkg/core/interop/runtime/witness.go index e9e456d0f..cc951724d 100644 --- a/pkg/core/interop/runtime/witness.go +++ b/pkg/core/interop/runtime/witness.go @@ -9,6 +9,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -28,6 +29,38 @@ func CheckHashedWitness(ic *interop.Context, hash util.Uint160) (bool, error) { return false, errors.New("script container is not a transaction") } +type scopeContext struct { + *vm.VM + ic *interop.Context +} + +func getContractGroups(v *vm.VM, ic *interop.Context, h util.Uint160) (manifest.Groups, error) { + if !v.Context().GetCallFlags().Has(callflag.ReadStates) { + return nil, errors.New("missing ReadStates call flag") + } + cs, err := ic.GetContract(h) + if err != nil { + return nil, nil // It's OK to not have the contract. + } + return manifest.Groups(cs.Manifest.Groups), nil +} + +func (sc scopeContext) checkScriptGroups(h util.Uint160, k *keys.PublicKey) (bool, error) { + groups, err := getContractGroups(sc.VM, sc.ic, h) + if err != nil { + return false, err + } + return groups.Contains(k), nil +} + +func (sc scopeContext) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) { + return sc.checkScriptGroups(sc.GetCallingScriptHash(), k) +} + +func (sc scopeContext) CurrentScriptHasGroup(k *keys.PublicKey) (bool, error) { + return sc.checkScriptGroups(sc.GetCurrentScriptHash(), k) +} + func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash util.Uint160) (bool, error) { for _, c := range tx.Signers { if c.Account == hash { @@ -50,19 +83,26 @@ func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash } } if c.Scopes&transaction.CustomGroups != 0 { - if !v.Context().GetCallFlags().Has(callflag.ReadStates) { - return false, errors.New("missing ReadStates call flag") - } - cs, err := ic.GetContract(v.GetCurrentScriptHash()) + groups, err := getContractGroups(v, ic, v.GetCurrentScriptHash()) if err != nil { - return false, nil + return false, err } // check if the current group is the required one for _, allowedGroup := range c.AllowedGroups { - for _, group := range cs.Manifest.Groups { - if group.PublicKey.Equal(allowedGroup) { - return true, nil - } + if groups.Contains(allowedGroup) { + return true, nil + } + } + } + if c.Scopes&transaction.Rules != 0 { + ctx := scopeContext{v, ic} + for _, r := range c.Rules { + res, err := r.Condition.Match(ctx) + if err != nil { + return false, err + } + if res { + return r.Action == transaction.WitnessAllow, nil } } } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index f86b73858..e2adda0a8 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -1189,6 +1189,28 @@ func TestRuntimeCheckWitness(t *testing.T) { ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) check(t, ic, hash.BytesBE(), true) }) + t.Run("Rules, missing ReadStates flag", func(t *testing.T) { + hash := random.Uint160() + pk, err := keys.NewPrivateKey() + require.NoError(t, err) + tx := &transaction.Transaction{ + Signers: []transaction.Signer{ + { + Account: hash, + Scopes: transaction.Rules, + Rules: []transaction.WitnessRule{{ + Action: transaction.WitnessAllow, + Condition: (*transaction.ConditionGroup)(pk.PublicKey()), + }}, + }, + }, + } + ic.Container = tx + callingScriptHash := scriptHash + loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All) + ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) + check(t, ic, hash.BytesBE(), true) + }) }) }) t.Run("positive", func(t *testing.T) { @@ -1301,6 +1323,64 @@ func TestRuntimeCheckWitness(t *testing.T) { check(t, ic, targetHash.BytesBE(), false, true) }) }) + t.Run("Rules", func(t *testing.T) { + t.Run("no match", func(t *testing.T) { + hash := random.Uint160() + tx := &transaction.Transaction{ + Signers: []transaction.Signer{ + { + Account: hash, + Scopes: transaction.Rules, + Rules: []transaction.WitnessRule{{ + Action: transaction.WitnessAllow, + Condition: (*transaction.ConditionScriptHash)(&hash), + }}, + }, + }, + } + loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) + ic.Container = tx + check(t, ic, hash.BytesBE(), false, false) + }) + t.Run("allow", func(t *testing.T) { + hash := random.Uint160() + var cond = true + tx := &transaction.Transaction{ + Signers: []transaction.Signer{ + { + Account: hash, + Scopes: transaction.Rules, + Rules: []transaction.WitnessRule{{ + Action: transaction.WitnessAllow, + Condition: (*transaction.ConditionBoolean)(&cond), + }}, + }, + }, + } + loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) + ic.Container = tx + check(t, ic, hash.BytesBE(), false, true) + }) + t.Run("deny", func(t *testing.T) { + hash := random.Uint160() + var cond = true + tx := &transaction.Transaction{ + Signers: []transaction.Signer{ + { + Account: hash, + Scopes: transaction.Rules, + Rules: []transaction.WitnessRule{{ + Action: transaction.WitnessDeny, + Condition: (*transaction.ConditionBoolean)(&cond), + }}, + }, + }, + } + loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) + ic.Container = tx + check(t, ic, hash.BytesBE(), false, false) + }) + }) t.Run("bad scope", func(t *testing.T) { hash := random.Uint160() tx := &transaction.Transaction{ diff --git a/pkg/core/transaction/signer.go b/pkg/core/transaction/signer.go index ef409e25b..74400daf9 100644 --- a/pkg/core/transaction/signer.go +++ b/pkg/core/transaction/signer.go @@ -17,6 +17,7 @@ type Signer struct { Scopes WitnessScope `json:"scopes"` AllowedContracts []util.Uint160 `json:"allowedcontracts,omitempty"` AllowedGroups []*keys.PublicKey `json:"allowedgroups,omitempty"` + Rules []WitnessRule `json:"rules,omitempty"` } // EncodeBinary implements Serializable interface. @@ -29,6 +30,9 @@ func (c *Signer) EncodeBinary(bw *io.BinWriter) { if c.Scopes&CustomGroups != 0 { bw.WriteArray(c.AllowedGroups) } + if c.Scopes&Rules != 0 { + bw.WriteArray(c.Rules) + } } // DecodeBinary implements Serializable interface. @@ -49,4 +53,7 @@ func (c *Signer) DecodeBinary(br *io.BinReader) { if c.Scopes&CustomGroups != 0 { br.ReadArray(&c.AllowedGroups, maxSubitems) } + if c.Scopes&Rules != 0 { + br.ReadArray(&c.Rules, maxSubitems) + } } diff --git a/pkg/core/transaction/witness_condition.go b/pkg/core/transaction/witness_condition.go new file mode 100644 index 000000000..cb41b7746 --- /dev/null +++ b/pkg/core/transaction/witness_condition.go @@ -0,0 +1,575 @@ +package transaction + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +//go:generate stringer -type=WitnessConditionType -linecomment + +// WitnessConditionType encodes a type of witness condition. +type WitnessConditionType byte + +const ( + // WitnessBoolean is a generic boolean condition. + WitnessBoolean WitnessConditionType = 0x00 // Boolean + // WitnessNot reverses another condition. + WitnessNot WitnessConditionType = 0x01 // Not + // WitnessAnd means that all conditions must be met. + WitnessAnd WitnessConditionType = 0x02 // And + // WitnessOr means that any of conditions must be met. + WitnessOr WitnessConditionType = 0x03 // Or + // WitnessScriptHash matches executing contract's script hash. + WitnessScriptHash WitnessConditionType = 0x18 // ScriptHash + // WitnessGroup matches executing contract's group key. + WitnessGroup WitnessConditionType = 0x19 // Group + // WitnessCalledByEntry matches when current script is an entry script or is called by an entry script. + WitnessCalledByEntry WitnessConditionType = 0x20 // CalledByEntry + // WitnessCalledByContract matches when current script is called by the specified contract. + WitnessCalledByContract WitnessConditionType = 0x28 // CalledByContract + // WitnessCalledByGroup matches when current script is called by contract belonging to the specified group. + WitnessCalledByGroup WitnessConditionType = 0x29 // CalledByGroup + + // MaxConditionNesting limits the maximum allowed level of condition nesting. + MaxConditionNesting = 2 +) + +// WitnessCondition is a condition of WitnessRule. +type WitnessCondition interface { + // Type returns a type of this condition. + Type() WitnessConditionType + // Match checks whether this condition matches current context. + Match(MatchContext) (bool, error) + // EncodeBinary allows to serialize condition to its binary + // representation (including type data). + EncodeBinary(*io.BinWriter) + // DecodeBinarySpecific decodes type-specific binary data from the given + // reader (not including type data). + DecodeBinarySpecific(*io.BinReader, int) + + json.Marshaler +} + +// MatchContext is a set of methods from execution engine needed to perform the +// witness check. +type MatchContext interface { + GetCallingScriptHash() util.Uint160 + GetCurrentScriptHash() util.Uint160 + GetEntryScriptHash() util.Uint160 + CallingScriptHasGroup(*keys.PublicKey) (bool, error) + CurrentScriptHasGroup(*keys.PublicKey) (bool, error) +} + +type ( + // ConditionBoolean is a boolean condition type. + ConditionBoolean bool + // ConditionNot inverses the meaning of contained condition. + ConditionNot struct { + Condition WitnessCondition + } + // ConditionAnd is a set of conditions required to match. + ConditionAnd []WitnessCondition + // ConditionOr is a set of conditions one of which is required to match. + ConditionOr []WitnessCondition + // ConditionScriptHash is a condition matching executing script hash. + ConditionScriptHash util.Uint160 + // ConditionGroup is a condition matching executing script group. + ConditionGroup keys.PublicKey + // ConditionCalledByEntry is a condition matching entry script or one directly called by it. + ConditionCalledByEntry struct{} + // ConditionCalledByContract is a condition matching calling script hash. + ConditionCalledByContract util.Uint160 + // ConditionCalledByGroup is a condition matching calling script group. + ConditionCalledByGroup keys.PublicKey +) + +// conditionAux is used for JSON marshaling/unmarshaling. +type conditionAux struct { + Expression json.RawMessage `json:"expression,omitempty"` // Can be either boolean or conditionAux. + Expressions []json.RawMessage `json:"expressions,omitempty"` + Group *keys.PublicKey `json:"group,omitempty"` + Hash *util.Uint160 `json:"hash,omitempty"` + Type string `json:"type"` +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionBoolean) Type() WitnessConditionType { + return WitnessBoolean +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionBoolean) Match(_ MatchContext) (bool, error) { + return bool(*c), nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionBoolean) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + w.WriteBool(bool(*c)) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionBoolean) DecodeBinarySpecific(r *io.BinReader, maxDepth int) { + *c = ConditionBoolean(r.ReadBool()) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionBoolean) MarshalJSON() ([]byte, error) { + boolJSON, _ := json.Marshal(bool(*c)) // Simple boolean can't fail. + aux := conditionAux{ + Type: c.Type().String(), + Expression: json.RawMessage(boolJSON), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionNot) Type() WitnessConditionType { + return WitnessNot +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionNot) Match(ctx MatchContext) (bool, error) { + res, err := c.Condition.Match(ctx) + return ((err == nil) && !res), err +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionNot) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + c.Condition.EncodeBinary(w) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionNot) DecodeBinarySpecific(r *io.BinReader, maxDepth int) { + c.Condition = decodeBinaryCondition(r, maxDepth-1) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionNot) MarshalJSON() ([]byte, error) { + condJSON, err := json.Marshal(c.Condition) + if err != nil { + return nil, err + } + aux := conditionAux{ + Type: c.Type().String(), + Expression: json.RawMessage(condJSON), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionAnd) Type() WitnessConditionType { + return WitnessAnd +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionAnd) Match(ctx MatchContext) (bool, error) { + for _, cond := range *c { + res, err := cond.Match(ctx) + if err != nil { + return false, err + } + if !res { + return false, nil + } + } + return true, nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionAnd) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + w.WriteArray([]WitnessCondition(*c)) +} + +func readArrayOfConditions(r *io.BinReader, maxDepth int) []WitnessCondition { + l := r.ReadVarUint() + if l == 0 { + r.Err = errors.New("empty array of conditions") + return nil + } + if l > maxSubitems { + r.Err = errors.New("too many elements") + return nil + } + a := make([]WitnessCondition, l) + for i := 0; i < int(l); i++ { + a[i] = decodeBinaryCondition(r, maxDepth-1) + } + if r.Err != nil { + return nil + } + return a +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionAnd) DecodeBinarySpecific(r *io.BinReader, maxDepth int) { + a := readArrayOfConditions(r, maxDepth) + if r.Err == nil { + *c = a + } +} + +func arrayToJSON(c WitnessCondition, a []WitnessCondition) ([]byte, error) { + exprs := make([]json.RawMessage, len(a)) + for i := 0; i < len(a); i++ { + b, err := a[i].MarshalJSON() + if err != nil { + return nil, err + } + exprs[i] = json.RawMessage(b) + } + aux := conditionAux{ + Type: c.Type().String(), + Expressions: exprs, + } + return json.Marshal(aux) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionAnd) MarshalJSON() ([]byte, error) { + return arrayToJSON(c, []WitnessCondition(*c)) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionOr) Type() WitnessConditionType { + return WitnessOr +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionOr) Match(ctx MatchContext) (bool, error) { + for _, cond := range *c { + res, err := cond.Match(ctx) + if err != nil { + return false, err + } + if res { + return true, nil + } + } + return false, nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionOr) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + w.WriteArray([]WitnessCondition(*c)) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionOr) DecodeBinarySpecific(r *io.BinReader, maxDepth int) { + a := readArrayOfConditions(r, maxDepth) + if r.Err == nil { + *c = a + } +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionOr) MarshalJSON() ([]byte, error) { + return arrayToJSON(c, []WitnessCondition(*c)) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionScriptHash) Type() WitnessConditionType { + return WitnessScriptHash +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionScriptHash) Match(ctx MatchContext) (bool, error) { + return util.Uint160(*c).Equals(ctx.GetCurrentScriptHash()), nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionScriptHash) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + w.WriteBytes(c[:]) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionScriptHash) DecodeBinarySpecific(r *io.BinReader, _ int) { + r.ReadBytes(c[:]) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionScriptHash) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + Hash: (*util.Uint160)(c), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionGroup) Type() WitnessConditionType { + return WitnessGroup +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionGroup) Match(ctx MatchContext) (bool, error) { + return ctx.CurrentScriptHasGroup((*keys.PublicKey)(c)) +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionGroup) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + (*keys.PublicKey)(c).EncodeBinary(w) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionGroup) DecodeBinarySpecific(r *io.BinReader, _ int) { + (*keys.PublicKey)(c).DecodeBinary(r) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionGroup) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + Group: (*keys.PublicKey)(c), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c ConditionCalledByEntry) Type() WitnessConditionType { + return WitnessCalledByEntry +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c ConditionCalledByEntry) Match(ctx MatchContext) (bool, error) { + entry := ctx.GetEntryScriptHash() + return entry.Equals(ctx.GetCallingScriptHash()) || entry.Equals(ctx.GetCurrentScriptHash()), nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c ConditionCalledByEntry) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c ConditionCalledByEntry) DecodeBinarySpecific(_ *io.BinReader, _ int) { +} + +// MarshalJSON implements json.Marshaler interface. +func (c ConditionCalledByEntry) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionCalledByContract) Type() WitnessConditionType { + return WitnessCalledByContract +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionCalledByContract) Match(ctx MatchContext) (bool, error) { + return util.Uint160(*c).Equals(ctx.GetCallingScriptHash()), nil +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionCalledByContract) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + w.WriteBytes(c[:]) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionCalledByContract) DecodeBinarySpecific(r *io.BinReader, _ int) { + r.ReadBytes(c[:]) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionCalledByContract) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + Hash: (*util.Uint160)(c), + } + return json.Marshal(aux) +} + +// Type implements WitnessCondition interface and returns condition type. +func (c *ConditionCalledByGroup) Type() WitnessConditionType { + return WitnessCalledByGroup +} + +// Match implements WitnessCondition interface checking whether this condition +// matches given context. +func (c *ConditionCalledByGroup) Match(ctx MatchContext) (bool, error) { + return ctx.CallingScriptHasGroup((*keys.PublicKey)(c)) +} + +// EncodeBinary implements WitnessCondition interface allowing to serialize condition. +func (c *ConditionCalledByGroup) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) + (*keys.PublicKey)(c).EncodeBinary(w) +} + +// DecodeBinarySpecific implements WitnessCondition interface allowing to +// deserialize condition-specific data. +func (c *ConditionCalledByGroup) DecodeBinarySpecific(r *io.BinReader, _ int) { + (*keys.PublicKey)(c).DecodeBinary(r) +} + +// MarshalJSON implements json.Marshaler interface. +func (c *ConditionCalledByGroup) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + Group: (*keys.PublicKey)(c), + } + return json.Marshal(aux) +} + +// DecodeBinaryCondition decodes and returns condition from the given binary stream. +func DecodeBinaryCondition(r *io.BinReader) WitnessCondition { + return decodeBinaryCondition(r, MaxConditionNesting) +} + +func decodeBinaryCondition(r *io.BinReader, maxDepth int) WitnessCondition { + if maxDepth <= 0 { + r.Err = errors.New("too many nesting levels") + return nil + } + t := WitnessConditionType(r.ReadB()) + if r.Err != nil { + return nil + } + var res WitnessCondition + switch t { + case WitnessBoolean: + var v ConditionBoolean + res = &v + case WitnessNot: + res = &ConditionNot{} + case WitnessAnd: + res = &ConditionAnd{} + case WitnessOr: + res = &ConditionOr{} + case WitnessScriptHash: + res = &ConditionScriptHash{} + case WitnessGroup: + res = &ConditionGroup{} + case WitnessCalledByEntry: + res = ConditionCalledByEntry{} + case WitnessCalledByContract: + res = &ConditionCalledByContract{} + case WitnessCalledByGroup: + res = &ConditionCalledByGroup{} + default: + r.Err = errors.New("invalid condition type") + return nil + } + res.DecodeBinarySpecific(r, maxDepth) + if r.Err != nil { + return nil + } + return res +} + +func unmarshalArrayOfConditionJSONs(arr []json.RawMessage, maxDepth int) ([]WitnessCondition, error) { + l := len(arr) + if l == 0 { + return nil, errors.New("empty array of conditions") + } + if l >= maxSubitems { + return nil, errors.New("too many elements") + } + res := make([]WitnessCondition, l) + for i := range arr { + v, err := unmarshalConditionJSON(arr[i], maxDepth-1) + if err != nil { + return nil, err + } + res[i] = v + } + return res, nil +} + +// UnmarshalConditionJSON unmarshalls condition from the given JSON data. +func UnmarshalConditionJSON(data []byte) (WitnessCondition, error) { + return unmarshalConditionJSON(data, MaxConditionNesting) +} + +func unmarshalConditionJSON(data []byte, maxDepth int) (WitnessCondition, error) { + if maxDepth <= 0 { + return nil, errors.New("too many nesting levels") + } + aux := &conditionAux{} + err := json.Unmarshal(data, aux) + if err != nil { + return nil, err + } + var res WitnessCondition + switch aux.Type { + case WitnessBoolean.String(): + var v bool + err = json.Unmarshal(aux.Expression, &v) + if err != nil { + return nil, err + } + res = (*ConditionBoolean)(&v) + case WitnessNot.String(): + v, err := unmarshalConditionJSON(aux.Expression, maxDepth-1) + if err != nil { + return nil, err + } + res = &ConditionNot{Condition: v} + case WitnessAnd.String(): + v, err := unmarshalArrayOfConditionJSONs(aux.Expressions, maxDepth) + if err != nil { + return nil, err + } + res = (*ConditionAnd)(&v) + case WitnessOr.String(): + v, err := unmarshalArrayOfConditionJSONs(aux.Expressions, maxDepth) + if err != nil { + return nil, err + } + res = (*ConditionOr)(&v) + case WitnessScriptHash.String(): + if aux.Hash == nil { + return nil, errors.New("no hash specified") + } + res = (*ConditionScriptHash)(aux.Hash) + case WitnessGroup.String(): + if aux.Group == nil { + return nil, errors.New("no group specified") + } + res = (*ConditionGroup)(aux.Group) + case WitnessCalledByEntry.String(): + res = ConditionCalledByEntry{} + case WitnessCalledByContract.String(): + if aux.Hash == nil { + return nil, errors.New("no hash specified") + } + res = (*ConditionCalledByContract)(aux.Hash) + case WitnessCalledByGroup.String(): + if aux.Group == nil { + return nil, errors.New("no group specified") + } + res = (*ConditionCalledByGroup)(aux.Group) + default: + return nil, errors.New("invalid condition type") + } + return res, nil +} diff --git a/pkg/core/transaction/witness_condition_test.go b/pkg/core/transaction/witness_condition_test.go new file mode 100644 index 000000000..22afbd232 --- /dev/null +++ b/pkg/core/transaction/witness_condition_test.go @@ -0,0 +1,325 @@ +package transaction + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +type InvalidCondition struct{} + +func (c InvalidCondition) Type() WitnessConditionType { + return 0xff +} +func (c InvalidCondition) Match(_ MatchContext) (bool, error) { + return true, nil +} +func (c InvalidCondition) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(c.Type())) +} +func (c InvalidCondition) DecodeBinarySpecific(r *io.BinReader, _ int) { +} +func (c InvalidCondition) MarshalJSON() ([]byte, error) { + aux := conditionAux{ + Type: c.Type().String(), + } + return json.Marshal(aux) +} + +type condCase struct { + condition WitnessCondition + success bool +} + +func TestWitnessConditionSerDes(t *testing.T) { + var someBool bool + pk, err := keys.NewPrivateKey() + require.NoError(t, err) + var cases = []condCase{ + {(*ConditionBoolean)(&someBool), true}, + {&ConditionNot{(*ConditionBoolean)(&someBool)}, true}, + {&ConditionAnd{(*ConditionBoolean)(&someBool), (*ConditionBoolean)(&someBool)}, true}, + {&ConditionOr{(*ConditionBoolean)(&someBool), (*ConditionBoolean)(&someBool)}, true}, + {&ConditionScriptHash{1, 2, 3}, true}, + {(*ConditionGroup)(pk.PublicKey()), true}, + {ConditionCalledByEntry{}, true}, + {&ConditionCalledByContract{1, 2, 3}, true}, + {(*ConditionCalledByGroup)(pk.PublicKey()), true}, + {InvalidCondition{}, false}, + {&ConditionAnd{}, false}, + {&ConditionOr{}, false}, + {&ConditionNot{&ConditionNot{&ConditionNot{(*ConditionBoolean)(&someBool)}}}, false}, + } + var maxSubCondAnd = &ConditionAnd{} + var maxSubCondOr = &ConditionAnd{} + for i := 0; i < maxSubitems+1; i++ { + *maxSubCondAnd = append(*maxSubCondAnd, (*ConditionBoolean)(&someBool)) + *maxSubCondOr = append(*maxSubCondOr, (*ConditionBoolean)(&someBool)) + } + cases = append(cases, condCase{maxSubCondAnd, false}) + cases = append(cases, condCase{maxSubCondOr, false}) + t.Run("binary", func(t *testing.T) { + for i, c := range cases { + w := io.NewBufBinWriter() + c.condition.EncodeBinary(w.BinWriter) + require.NoError(t, w.Err) + b := w.Bytes() + + r := io.NewBinReaderFromBuf(b) + res := DecodeBinaryCondition(r) + if !c.success { + require.Nil(t, res) + require.Errorf(t, r.Err, "case %d", i) + continue + } + require.NoErrorf(t, r.Err, "case %d", i) + require.Equal(t, c.condition, res) + } + }) + t.Run("json", func(t *testing.T) { + for i, c := range cases { + jj, err := c.condition.MarshalJSON() + require.NoError(t, err) + res, err := UnmarshalConditionJSON(jj) + if !c.success { + require.Errorf(t, err, "case %d, json %s", i, jj) + continue + } + require.NoErrorf(t, err, "case %d, json %s", i, jj) + require.Equal(t, c.condition, res) + } + }) +} + +func TestWitnessConditionZeroDeser(t *testing.T) { + r := io.NewBinReaderFromBuf([]byte{}) + res := DecodeBinaryCondition(r) + require.Nil(t, res) + require.Error(t, r.Err) +} + +func TestWitnessConditionJSONErrors(t *testing.T) { + var cases = []string{ + `[]`, + `{}`, + `{"type":"Boolean"}`, + `{"type":"Not"}`, + `{"type":"And"}`, + `{"type":"Or"}`, + `{"type":"ScriptHash"}`, + `{"type":"Group"}`, + `{"type":"CalledByContract"}`, + `{"type":"CalledByGroup"}`, + `{"type":"Boolean", "expression":42}`, + `{"type":"Not", "expression":true}`, + `{"type":"And", "expressions":[{"type":"CalledByGroup"},{"type":"Not", "expression":true}]}`, + `{"type":"Or", "expressions":{"type":"CalledByGroup"}}`, + `{"type":"Or", "expressions":[{"type":"CalledByGroup"},{"type":"Not", "expression":false}]}`, + `{"type":"ScriptHash", "hash":"1122"}`, + `{"type":"Group", "group":"032211"}`, + `{"type":"CalledByContract", "hash":"1122"}`, + `{"type":"CalledByGroup", "group":"032211"}`, + } + for i := range cases { + res, err := UnmarshalConditionJSON([]byte(cases[i])) + require.Errorf(t, err, "case %d, json %s", i, cases[i]) + require.Nil(t, res) + } +} + +type TestMC struct { + calling util.Uint160 + current util.Uint160 + entry util.Uint160 + goodKey *keys.PublicKey + badKey *keys.PublicKey +} + +func (t *TestMC) GetCallingScriptHash() util.Uint160 { + return t.calling +} +func (t *TestMC) GetCurrentScriptHash() util.Uint160 { + return t.current +} +func (t *TestMC) GetEntryScriptHash() util.Uint160 { + return t.entry +} +func (t *TestMC) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) { + res, err := t.CurrentScriptHasGroup(k) + return !res, err // To differentiate from current we invert the logic value. +} +func (t *TestMC) CurrentScriptHasGroup(k *keys.PublicKey) (bool, error) { + if k.Equal(t.goodKey) { + return true, nil + } + if k.Equal(t.badKey) { + return false, errors.New("baaad key") + } + return false, nil +} + +func TestWitnessConditionMatch(t *testing.T) { + pkGood, err := keys.NewPrivateKey() + require.NoError(t, err) + pkBad, err := keys.NewPrivateKey() + require.NoError(t, err) + pkNeutral, err := keys.NewPrivateKey() + require.NoError(t, err) + entrySC := util.Uint160{1, 2, 3} + currentSC := util.Uint160{4, 5, 6} + tmc := &TestMC{ + calling: entrySC, + entry: entrySC, + current: currentSC, + goodKey: pkGood.PublicKey(), + badKey: pkBad.PublicKey(), + } + + t.Run("boolean", func(t *testing.T) { + var b bool + var c = (*ConditionBoolean)(&b) + res, err := c.Match(tmc) + require.NoError(t, err) + require.False(t, res) + b = true + res, err = c.Match(tmc) + require.NoError(t, err) + require.True(t, res) + }) + t.Run("not", func(t *testing.T) { + var b bool + var cInner = (*ConditionBoolean)(&b) + var cInner2 = (*ConditionGroup)(pkBad.PublicKey()) + var c = &ConditionNot{cInner} + var c2 = &ConditionNot{cInner2} + + res, err := c.Match(tmc) + require.NoError(t, err) + require.True(t, res) + b = true + res, err = c.Match(tmc) + require.NoError(t, err) + require.False(t, res) + _, err = c2.Match(tmc) + require.Error(t, err) + }) + t.Run("and", func(t *testing.T) { + var bFalse, bTrue bool + var cInnerFalse = (*ConditionBoolean)(&bFalse) + var cInnerTrue = (*ConditionBoolean)(&bTrue) + var cInnerBad = (*ConditionGroup)(pkBad.PublicKey()) + var c = &ConditionAnd{cInnerTrue, cInnerFalse, cInnerFalse} + var cBad = &ConditionAnd{cInnerTrue, cInnerBad} + + bTrue = true + res, err := c.Match(tmc) + require.NoError(t, err) + require.False(t, res) + bFalse = true + res, err = c.Match(tmc) + require.NoError(t, err) + require.True(t, res) + + _, err = cBad.Match(tmc) + require.Error(t, err) + }) + t.Run("or", func(t *testing.T) { + var bFalse, bTrue bool + var cInnerFalse = (*ConditionBoolean)(&bFalse) + var cInnerTrue = (*ConditionBoolean)(&bTrue) + var cInnerBad = (*ConditionGroup)(pkBad.PublicKey()) + var c = &ConditionOr{cInnerTrue, cInnerFalse, cInnerFalse} + var cBad = &ConditionOr{cInnerTrue, cInnerBad} + + bTrue = true + res, err := c.Match(tmc) + require.NoError(t, err) + require.True(t, res) + bTrue = false + res, err = c.Match(tmc) + require.NoError(t, err) + require.False(t, res) + + _, err = cBad.Match(tmc) + require.Error(t, err) + }) + t.Run("script hash", func(t *testing.T) { + var cEntry = (*ConditionScriptHash)(&entrySC) + var cCurrent = (*ConditionScriptHash)(¤tSC) + + res, err := cEntry.Match(tmc) + require.NoError(t, err) + require.False(t, res) + res, err = cCurrent.Match(tmc) + require.NoError(t, err) + require.True(t, res) + }) + t.Run("group", func(t *testing.T) { + var cBad = (*ConditionGroup)(pkBad.PublicKey()) + var cGood = (*ConditionGroup)(pkGood.PublicKey()) + var cNeutral = (*ConditionGroup)(pkNeutral.PublicKey()) + + res, err := cGood.Match(tmc) + require.NoError(t, err) + require.True(t, res) + + res, err = cNeutral.Match(tmc) + require.NoError(t, err) + require.False(t, res) + + _, err = cBad.Match(tmc) + require.Error(t, err) + }) + t.Run("called by entry", func(t *testing.T) { + var c = ConditionCalledByEntry{} + + res, err := c.Match(tmc) + require.NoError(t, err) + require.True(t, res) + + tmc2 := *tmc + tmc2.entry = util.Uint160{0, 9, 8} + res, err = c.Match(&tmc2) + require.NoError(t, err) + require.False(t, res) + + tmc3 := *tmc + tmc3.calling = util.Uint160{} + tmc3.current = tmc3.entry + res, err = c.Match(&tmc3) + require.NoError(t, err) + require.True(t, res) + }) + t.Run("called by contract", func(t *testing.T) { + var cEntry = (*ConditionCalledByContract)(&entrySC) + var cCurrent = (*ConditionCalledByContract)(¤tSC) + + res, err := cEntry.Match(tmc) + require.NoError(t, err) + require.True(t, res) + res, err = cCurrent.Match(tmc) + require.NoError(t, err) + require.False(t, res) + }) + t.Run("called by group", func(t *testing.T) { + var cBad = (*ConditionCalledByGroup)(pkBad.PublicKey()) + var cGood = (*ConditionCalledByGroup)(pkGood.PublicKey()) + var cNeutral = (*ConditionCalledByGroup)(pkNeutral.PublicKey()) + + res, err := cGood.Match(tmc) + require.NoError(t, err) + require.False(t, res) + + res, err = cNeutral.Match(tmc) + require.NoError(t, err) + require.True(t, res) + + _, err = cBad.Match(tmc) + require.Error(t, err) + }) +} diff --git a/pkg/core/transaction/witness_rule.go b/pkg/core/transaction/witness_rule.go new file mode 100644 index 000000000..75ab4abcd --- /dev/null +++ b/pkg/core/transaction/witness_rule.go @@ -0,0 +1,86 @@ +package transaction + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +//go:generate stringer -type=WitnessAction -linecomment + +// WitnessAction represents an action to perform in WitnessRule if +// WitnessCondition matches. +type WitnessAction byte + +const ( + // WitnessDeny rejects current witness if condition is met. + WitnessDeny WitnessAction = 0 // Deny + // WitnessAllow approves current witness if condition is met. + WitnessAllow WitnessAction = 1 // Allow +) + +// WitnessRule represents a single rule for Rules witness scope. +type WitnessRule struct { + Action WitnessAction `json:"action"` + Condition WitnessCondition `json:"condition"` +} + +type witnessRuleAux struct { + Action string `json:"action"` + Condition json.RawMessage `json:"condition"` +} + +// EncodeBinary implements Serializable interface. +func (w *WitnessRule) EncodeBinary(bw *io.BinWriter) { + bw.WriteB(byte(w.Action)) + w.Condition.EncodeBinary(bw) +} + +// DecodeBinary implements Serializable interface. +func (w *WitnessRule) DecodeBinary(br *io.BinReader) { + w.Action = WitnessAction(br.ReadB()) + if br.Err == nil && w.Action != WitnessDeny && w.Action != WitnessAllow { + br.Err = errors.New("unknown witness rule action") + return + } + w.Condition = DecodeBinaryCondition(br) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (w *WitnessRule) MarshalJSON() ([]byte, error) { + cond, err := w.Condition.MarshalJSON() + if err != nil { + return nil, err + } + aux := &witnessRuleAux{ + Action: w.Action.String(), + Condition: cond, + } + return json.Marshal(aux) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (w *WitnessRule) UnmarshalJSON(data []byte) error { + aux := &witnessRuleAux{} + err := json.Unmarshal(data, aux) + if err != nil { + return err + } + var action WitnessAction + switch aux.Action { + case WitnessDeny.String(): + action = WitnessDeny + case WitnessAllow.String(): + action = WitnessAllow + default: + return errors.New("unknown witness rule action") + } + cond, err := UnmarshalConditionJSON(aux.Condition) + if err != nil { + return err + } + w.Action = action + w.Condition = cond + return nil +} diff --git a/pkg/core/transaction/witness_rule_test.go b/pkg/core/transaction/witness_rule_test.go new file mode 100644 index 000000000..70bf9c38d --- /dev/null +++ b/pkg/core/transaction/witness_rule_test.go @@ -0,0 +1,56 @@ +package transaction + +import ( + "encoding/json" + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func TestWitnessRuleSerDes(t *testing.T) { + var b bool + expected := &WitnessRule{ + Action: WitnessAllow, + Condition: (*ConditionBoolean)(&b), + } + actual := &WitnessRule{} + testserdes.EncodeDecodeBinary(t, expected, actual) +} + +func TestWitnessRuleSerDesBad(t *testing.T) { + var b bool + bad := &WitnessRule{ + Action: 0xff, + Condition: (*ConditionBoolean)(&b), + } + badB, err := testserdes.EncodeBinary(bad) + require.NoError(t, err) + err = testserdes.DecodeBinary(badB, &WitnessRule{}) + require.Error(t, err) +} + +func TestWitnessRuleJSON(t *testing.T) { + var b bool + expected := &WitnessRule{ + Action: WitnessDeny, + Condition: (*ConditionBoolean)(&b), + } + actual := &WitnessRule{} + testserdes.MarshalUnmarshalJSON(t, expected, actual) +} + +func TestWitnessRuleBadJSON(t *testing.T) { + var cases = []string{ + `{}`, + `[]`, + `{"action":"Allow"}`, + `{"action":"Unknown","condition":{"type":"Boolean", "expression":true}}`, + `{"action":"Allow","condition":{"type":"Boolean", "expression":42}}`, + } + for i := range cases { + actual := &WitnessRule{} + err := json.Unmarshal([]byte(cases[i]), actual) + require.Errorf(t, err, "case %d, json %s", i, cases[i]) + } +} diff --git a/pkg/core/transaction/witness_scope.go b/pkg/core/transaction/witness_scope.go index 06d552e17..8950290cd 100644 --- a/pkg/core/transaction/witness_scope.go +++ b/pkg/core/transaction/witness_scope.go @@ -22,6 +22,8 @@ const ( CustomContracts WitnessScope = 0x10 // CustomGroups define custom pubkey for group members. CustomGroups WitnessScope = 0x20 + // Rules is a set of conditions with boolean operators. + Rules WitnessScope = 0x40 // Global allows this witness in all contexts (default Neo2 behavior). // This cannot be combined with other flags. Global WitnessScope = 0x80 @@ -42,6 +44,7 @@ func ScopesFromString(s string) (WitnessScope, error) { CalledByEntry.String(): CalledByEntry, CustomContracts.String(): CustomContracts, CustomGroups.String(): CustomGroups, + Rules.String(): Rules, None.String(): None, } var isGlobal bool @@ -61,6 +64,16 @@ func ScopesFromString(s string) (WitnessScope, error) { return result, nil } +func appendScopeString(str string, scopes WitnessScope, scope WitnessScope) string { + if scopes&scope != 0 { + if len(str) != 0 { + str += ", " + } + str += scope.String() + } + return str +} + // scopesToString converts witness scope to it's string representation. It uses // `, ` to separate scope names. func scopesToString(scopes WitnessScope) string { @@ -68,21 +81,10 @@ func scopesToString(scopes WitnessScope) string { return scopes.String() } var res string - if scopes&CalledByEntry != 0 { - res = CalledByEntry.String() - } - if scopes&CustomContracts != 0 { - if len(res) != 0 { - res += ", " - } - res += CustomContracts.String() - } - if scopes&CustomGroups != 0 { - if len(res) != 0 { - res += ", " - } - res += CustomGroups.String() - } + res = appendScopeString(res, scopes, CalledByEntry) + res = appendScopeString(res, scopes, CustomContracts) + res = appendScopeString(res, scopes, CustomGroups) + res = appendScopeString(res, scopes, Rules) return res } diff --git a/pkg/core/transaction/witness_scope_string.go b/pkg/core/transaction/witness_scope_string.go index 6934757fa..5f6666af1 100644 --- a/pkg/core/transaction/witness_scope_string.go +++ b/pkg/core/transaction/witness_scope_string.go @@ -12,6 +12,7 @@ func _() { _ = x[CalledByEntry-1] _ = x[CustomContracts-16] _ = x[CustomGroups-32] + _ = x[Rules-64] _ = x[Global-128] } @@ -19,7 +20,8 @@ const ( _WitnessScope_name_0 = "NoneCalledByEntry" _WitnessScope_name_1 = "CustomContracts" _WitnessScope_name_2 = "CustomGroups" - _WitnessScope_name_3 = "Global" + _WitnessScope_name_3 = "Rules" + _WitnessScope_name_4 = "Global" ) var ( @@ -28,14 +30,16 @@ var ( func (i WitnessScope) String() string { switch { - case i <= 1: + case 0 <= i && i <= 1: return _WitnessScope_name_0[_WitnessScope_index_0[i]:_WitnessScope_index_0[i+1]] case i == 16: return _WitnessScope_name_1 case i == 32: return _WitnessScope_name_2 - case i == 128: + case i == 64: return _WitnessScope_name_3 + case i == 128: + return _WitnessScope_name_4 default: return "WitnessScope(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/core/transaction/witnessaction_string.go b/pkg/core/transaction/witnessaction_string.go new file mode 100644 index 000000000..f2f7ddfed --- /dev/null +++ b/pkg/core/transaction/witnessaction_string.go @@ -0,0 +1,24 @@ +// Code generated by "stringer -type=WitnessAction -linecomment"; DO NOT EDIT. + +package transaction + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[WitnessDeny-0] + _ = x[WitnessAllow-1] +} + +const _WitnessAction_name = "DenyAllow" + +var _WitnessAction_index = [...]uint8{0, 4, 9} + +func (i WitnessAction) String() string { + if i >= WitnessAction(len(_WitnessAction_index)-1) { + return "WitnessAction(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _WitnessAction_name[_WitnessAction_index[i]:_WitnessAction_index[i+1]] +} diff --git a/pkg/core/transaction/witnessconditiontype_string.go b/pkg/core/transaction/witnessconditiontype_string.go new file mode 100644 index 000000000..96e0d8a38 --- /dev/null +++ b/pkg/core/transaction/witnessconditiontype_string.go @@ -0,0 +1,50 @@ +// Code generated by "stringer -type=WitnessConditionType -linecomment"; DO NOT EDIT. + +package transaction + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[WitnessBoolean-0] + _ = x[WitnessNot-1] + _ = x[WitnessAnd-2] + _ = x[WitnessOr-3] + _ = x[WitnessScriptHash-24] + _ = x[WitnessGroup-25] + _ = x[WitnessCalledByEntry-32] + _ = x[WitnessCalledByContract-40] + _ = x[WitnessCalledByGroup-41] +} + +const ( + _WitnessConditionType_name_0 = "BooleanNotAndOr" + _WitnessConditionType_name_1 = "ScriptHashGroup" + _WitnessConditionType_name_2 = "CalledByEntry" + _WitnessConditionType_name_3 = "CalledByContractCalledByGroup" +) + +var ( + _WitnessConditionType_index_0 = [...]uint8{0, 7, 10, 13, 15} + _WitnessConditionType_index_1 = [...]uint8{0, 10, 15} + _WitnessConditionType_index_3 = [...]uint8{0, 16, 29} +) + +func (i WitnessConditionType) String() string { + switch { + case 0 <= i && i <= 3: + return _WitnessConditionType_name_0[_WitnessConditionType_index_0[i]:_WitnessConditionType_index_0[i+1]] + case 24 <= i && i <= 25: + i -= 24 + return _WitnessConditionType_name_1[_WitnessConditionType_index_1[i]:_WitnessConditionType_index_1[i+1]] + case i == 32: + return _WitnessConditionType_name_2 + case 40 <= i && i <= 41: + i -= 40 + return _WitnessConditionType_name_3[_WitnessConditionType_index_3[i]:_WitnessConditionType_index_3[i+1]] + default: + return "WitnessConditionType(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/pkg/smartcontract/manifest/group.go b/pkg/smartcontract/manifest/group.go index 920f17d38..1e5a942e2 100644 --- a/pkg/smartcontract/manifest/group.go +++ b/pkg/smartcontract/manifest/group.go @@ -64,6 +64,15 @@ func (g Groups) AreValid(h util.Uint160) error { return nil } +func (g Groups) Contains(k *keys.PublicKey) bool { + for i := range g { + if k.Equal(g[i].PublicKey) { + return true + } + } + return false +} + // MarshalJSON implements json.Marshaler interface. func (g *Group) MarshalJSON() ([]byte, error) { aux := &groupAux{ diff --git a/pkg/smartcontract/manifest/group_test.go b/pkg/smartcontract/manifest/group_test.go index 464ada130..6c50832c8 100644 --- a/pkg/smartcontract/manifest/group_test.go +++ b/pkg/smartcontract/manifest/group_test.go @@ -41,3 +41,17 @@ func TestGroupsAreValid(t *testing.T) { gps = Groups{gcorrect, gcorrect} require.Error(t, gps.AreValid(h)) } + +func TestGroupsContains(t *testing.T) { + priv, err := keys.NewPrivateKey() + require.NoError(t, err) + priv2, err := keys.NewPrivateKey() + require.NoError(t, err) + priv3, err := keys.NewPrivateKey() + require.NoError(t, err) + g1 := Group{priv.PublicKey(), nil} + g2 := Group{priv2.PublicKey(), nil} + gps := Groups{g1, g2} + require.True(t, gps.Contains(priv2.PublicKey())) + require.False(t, gps.Contains(priv3.PublicKey())) +}