diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index a5eb71c52..edb9236a8 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -33,6 +33,8 @@ type Management struct { mtx sync.RWMutex contracts map[util.Uint160]*state.Contract + // nep17 is a map of NEP17-compliant contracts which is updated with every PostPersist. + nep17 map[util.Uint160]struct{} } const ( @@ -63,6 +65,7 @@ func newManagement() *Management { var m = &Management{ ContractMD: *interop.NewContractMD(nativenames.Management, managementContractID), contracts: make(map[util.Uint160]*state.Contract), + nep17: make(map[util.Uint160]struct{}), } defer m.UpdateHash() @@ -471,6 +474,9 @@ func (m *Management) OnPersist(ic *interop.Context) error { } m.mtx.Lock() m.contracts[md.Hash] = cs + if md.Manifest.IsStandardSupported(manifest.NEP17StandardName) { + m.nep17[md.Hash] = struct{}{} + } m.mtx.Unlock() } @@ -492,6 +498,9 @@ func (m *Management) InitializeCache(d dao.DAO) error { return } m.contracts[cs.Hash] = cs + if cs.Manifest.IsStandardSupported(manifest.NEP17StandardName) { + m.nep17[cs.Hash] = struct{}{} + } }) return initErr } @@ -507,14 +516,33 @@ func (m *Management) PostPersist(ic *interop.Context) error { if err != nil { // Contract was destroyed. delete(m.contracts, h) + delete(m.nep17, h) continue } m.contracts[h] = newCs + if newCs.Manifest.IsStandardSupported(manifest.NEP17StandardName) { + m.nep17[h] = struct{}{} + } else { + delete(m.nep17, h) + } } m.mtx.Unlock() return nil } +// GetNEP17Contracts returns hashes of all deployed contracts that support NEP17 standard. The list +// is updated every PostPersist, so until PostPersist is called, the result for the previous block +// is returned. +func (m *Management) GetNEP17Contracts() []util.Uint160 { + m.mtx.RLock() + result := make([]util.Uint160, 0, len(m.nep17)) + for h := range m.nep17 { + result = append(result, h) + } + m.mtx.RUnlock() + return result +} + // Initialize implements Contract interface. func (m *Management) Initialize(ic *interop.Context) error { if err := setIntWithKey(m.ID, ic.DAO, keyMinimumDeploymentFee, defaultMinimumDeploymentFee); err != nil { diff --git a/pkg/core/native/management_test.go b/pkg/core/native/management_test.go index fedb9ed0e..ca71e4488 100644 --- a/pkg/core/native/management_test.go +++ b/pkg/core/native/management_test.go @@ -83,3 +83,50 @@ func TestManagement_Initialize(t *testing.T) { require.Error(t, mgmt.InitializeCache(d)) }) } + +func TestManagement_GetNEP17Contracts(t *testing.T) { + mgmt := newManagement() + d := dao.NewCached(dao.NewSimple(storage.NewMemoryStore(), false)) + err := mgmt.Initialize(&interop.Context{DAO: d}) + require.NoError(t, err) + + require.Empty(t, mgmt.GetNEP17Contracts()) + + // Deploy NEP17 contract + script := []byte{byte(opcode.RET)} + sender := util.Uint160{1, 2, 3} + ne, err := nef.NewFile(script) + require.NoError(t, err) + manif := manifest.NewManifest("Test") + manif.ABI.Methods = append(manif.ABI.Methods, manifest.Method{ + Name: "dummy", + ReturnType: smartcontract.VoidType, + Parameters: []manifest.Parameter{}, + }) + manif.SupportedStandards = []string{manifest.NEP17StandardName} + c1, err := mgmt.Deploy(d, sender, ne, manif) + require.NoError(t, err) + + // PostPersist is not yet called, thus no NEP17 contracts are expected + require.Empty(t, mgmt.GetNEP17Contracts()) + + // Call PostPersist, check c1 contract hash is returned + require.NoError(t, mgmt.PostPersist(&interop.Context{DAO: d})) + require.Equal(t, []util.Uint160{c1.Hash}, mgmt.GetNEP17Contracts()) + + // Update contract + manif.ABI.Methods = append(manif.ABI.Methods, manifest.Method{ + Name: "dummy2", + ReturnType: smartcontract.VoidType, + Parameters: []manifest.Parameter{}, + }) + c2, err := mgmt.Update(d, c1.Hash, ne, manif) + require.NoError(t, err) + + // No changes expected before PostPersist call. + require.Equal(t, []util.Uint160{c1.Hash}, mgmt.GetNEP17Contracts()) + + // Call PostPersist, check c2 contract hash is returned + require.NoError(t, mgmt.PostPersist(&interop.Context{DAO: d})) + require.Equal(t, []util.Uint160{c2.Hash}, mgmt.GetNEP17Contracts()) +} diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index 57ec4c54e..82f105581 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -607,3 +607,18 @@ func TestMinimumDeploymentFee(t *testing.T) { testGetSet(t, chain, chain.contracts.Management.Hash, "MinimumDeploymentFee", 10_00000000, 0, 0) } + +func TestManagement_GetNEP17Contracts(t *testing.T) { + t.Run("empty chain", func(t *testing.T) { + chain := newTestChain(t) + require.ElementsMatch(t, []util.Uint160{chain.contracts.NEO.Hash, chain.contracts.GAS.Hash}, chain.contracts.Management.GetNEP17Contracts()) + }) + + t.Run("test chain", func(t *testing.T) { + chain := newTestChain(t) + initBasicChain(t, chain) + rublesHash, err := chain.GetContractScriptHash(1) + require.NoError(t, err) + require.ElementsMatch(t, []util.Uint160{chain.contracts.NEO.Hash, chain.contracts.GAS.Hash, rublesHash}, chain.contracts.Management.GetNEP17Contracts()) + }) +} diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index 8fde178dc..65582fa58 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -119,6 +119,16 @@ func (m *Manifest) IsValid(hash util.Uint160) error { return Permissions(m.Permissions).AreValid() } +// IsStandardSupported denotes whether the specified standard supported by the contract. +func (m *Manifest) IsStandardSupported(standard string) bool { + for _, st := range m.SupportedStandards { + if st == standard { + return true + } + } + return false +} + // ToStackItem converts Manifest to stackitem.Item. func (m *Manifest) ToStackItem() (stackitem.Item, error) { groups := make([]stackitem.Item, len(m.Groups)) diff --git a/pkg/smartcontract/manifest/manifest_test.go b/pkg/smartcontract/manifest/manifest_test.go index 4ac0589b7..82fcc13b4 100644 --- a/pkg/smartcontract/manifest/manifest_test.go +++ b/pkg/smartcontract/manifest/manifest_test.go @@ -429,3 +429,15 @@ func TestExtraToStackItem(t *testing.T) { require.Equal(t, tc.expected, string(actual)) } } + +func TestManifest_IsStandardSupported(t *testing.T) { + m := &Manifest{ + SupportedStandards: []string{NEP17StandardName, NEP17Payable, NEP11Payable}, + } + for _, st := range m.SupportedStandards { + require.True(t, m.IsStandardSupported(st)) + } + require.False(t, m.IsStandardSupported(NEP11StandardName)) + require.False(t, m.IsStandardSupported("")) + require.False(t, m.IsStandardSupported("unknown standard")) +}