diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index ff9f6973c..2b26be5c1 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -133,6 +133,18 @@ func (t *Transaction) HasAttribute(typ AttrType) bool { return false } +// GetAttributes returns the list of transaction's attributes of the given type. +// Returns nil in case if attributes not found. +func (t *Transaction) GetAttributes(typ AttrType) []Attribute { + var result []Attribute + for _, attr := range t.Attributes { + if attr.Type == typ { + result = append(result, attr) + } + } + return result +} + // decodeHashableFields decodes the fields that are used for signing the // transaction, which are all fields except the scripts. func (t *Transaction) decodeHashableFields(br *io.BinReader) { diff --git a/pkg/core/transaction/transaction_test.go b/pkg/core/transaction/transaction_test.go index 3741dc987..fd213accd 100644 --- a/pkg/core/transaction/transaction_test.go +++ b/pkg/core/transaction/transaction_test.go @@ -7,12 +7,13 @@ import ( "math" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestWitnessEncodeDecode(t *testing.T) { @@ -215,3 +216,35 @@ func TestTransaction_isValid(t *testing.T) { require.True(t, errors.Is(tx.isValid(), ErrEmptyScript)) }) } + +func TestTransaction_GetAttributes(t *testing.T) { + attributesTypes := []AttrType{ + HighPriority, + OracleResponseT, + NotValidBeforeT, + } + t.Run("no attributes", func(t *testing.T) { + tx := new(Transaction) + for _, typ := range attributesTypes { + require.Nil(t, tx.GetAttributes(typ)) + } + }) + t.Run("single attributes", func(t *testing.T) { + attrs := make([]Attribute, len(attributesTypes)) + for i, typ := range attributesTypes { + attrs[i] = Attribute{Type: typ} + } + tx := &Transaction{Attributes: attrs} + for _, typ := range attributesTypes { + require.Equal(t, []Attribute{{Type: typ}}, tx.GetAttributes(typ)) + } + }) + t.Run("multiple attributes", func(t *testing.T) { + typ := AttrType(ReservedLowerBound + 1) + conflictsAttrs := []Attribute{{Type: typ}, {Type: typ}} + tx := Transaction{ + Attributes: append([]Attribute{{Type: HighPriority}}, conflictsAttrs...), + } + require.Equal(t, conflictsAttrs, tx.GetAttributes(typ)) + }) +}