From 08cc04c3d5490385ae8827a6c14f51de62f45c26 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Mon, 15 Jun 2020 21:13:32 +0300 Subject: [PATCH] core: add native policy contract part of #904 --- pkg/core/blockchain.go | 13 +- pkg/core/helper_test.go | 11 + pkg/core/native/blocked_accounts.go | 53 +++ pkg/core/native/blocked_accounts_test.go | 53 +++ pkg/core/native/contract.go | 7 +- pkg/core/native/policy.go | 393 +++++++++++++++++++++++ pkg/core/native_policy_test.go | 260 +++++++++++++++ pkg/network/compress.go | 5 +- pkg/network/message.go | 10 +- pkg/network/payload/payload.go | 3 + 10 files changed, 794 insertions(+), 14 deletions(-) create mode 100644 pkg/core/native/blocked_accounts.go create mode 100644 pkg/core/native/blocked_accounts_test.go create mode 100644 pkg/core/native/policy.go create mode 100644 pkg/core/native_policy_test.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 1144b43f9..12ff6a22a 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1088,9 +1088,8 @@ func (bc *Blockchain) CalculateClaimable(value int64, startHeight, endHeight uin } // FeePerByte returns transaction network fee per byte. -// TODO: should be implemented as part of PolicyContract func (bc *Blockchain) FeePerByte() util.Fixed8 { - return util.Fixed8(1000) + return util.Fixed8(bc.contracts.Policy.GetFeePerByteInternal(bc.dao)) } // GetMemPool returns the memory pool of the blockchain. @@ -1101,8 +1100,9 @@ func (bc *Blockchain) GetMemPool() *mempool.Pool { // ApplyPolicyToTxSet applies configured policies to given transaction set. It // expects slice to be ordered by fee and returns a subslice of it. func (bc *Blockchain) ApplyPolicyToTxSet(txes []*transaction.Transaction) []*transaction.Transaction { - if bc.config.MaxTransactionsPerBlock != 0 && len(txes) > bc.config.MaxTransactionsPerBlock { - txes = txes[:bc.config.MaxTransactionsPerBlock] + maxTx := bc.contracts.Policy.GetMaxTransactionsPerBlockInternal(bc.dao) + if maxTx != 0 && len(txes) > int(maxTx) { + txes = txes[:maxTx] } return txes } @@ -1203,6 +1203,11 @@ func (bc *Blockchain) PoolTx(t *transaction.Transaction) error { return err } // Policying. + if ok, err := bc.contracts.Policy.CheckPolicy(bc.newInteropContext(trigger.Application, bc.dao, nil, t), t); err != nil { + return err + } else if !ok { + return ErrPolicy + } if err := bc.memPool.Add(t, bc); err != nil { switch err { case mempool.ErrOOM: diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go index a334baa33..f7b840cc5 100644 --- a/pkg/core/helper_test.go +++ b/pkg/core/helper_test.go @@ -387,6 +387,17 @@ func addSender(txs ...*transaction.Transaction) error { return nil } +func addCosigners(txs ...*transaction.Transaction) { + for _, tx := range txs { + tx.Cosigners = []transaction.Cosigner{{ + Account: neoOwner, + Scopes: transaction.CalledByEntry, + AllowedContracts: nil, + AllowedGroups: nil, + }} + } +} + func signTx(bc *Blockchain, txs ...*transaction.Transaction) error { validators, err := getValidators(bc.config) if err != nil { diff --git a/pkg/core/native/blocked_accounts.go b/pkg/core/native/blocked_accounts.go new file mode 100644 index 000000000..fb27ce34f --- /dev/null +++ b/pkg/core/native/blocked_accounts.go @@ -0,0 +1,53 @@ +package native + +import ( + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// BlockedAccounts represents a slice of blocked accounts hashes. +type BlockedAccounts []util.Uint160 + +// Bytes returns serialized BlockedAccounts. +func (ba *BlockedAccounts) Bytes() []byte { + w := io.NewBufBinWriter() + ba.EncodeBinary(w.BinWriter) + if w.Err != nil { + panic(w.Err) + } + return w.Bytes() +} + +// EncodeBinary implements io.Serializable interface. +func (ba *BlockedAccounts) EncodeBinary(w *io.BinWriter) { + w.WriteArray(*ba) +} + +// BlockedAccountsFromBytes converts serialized BlockedAccounts to structure. +func BlockedAccountsFromBytes(b []byte) (BlockedAccounts, error) { + ba := new(BlockedAccounts) + if len(b) == 0 { + return *ba, nil + } + r := io.NewBinReaderFromBuf(b) + ba.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + return *ba, nil +} + +// DecodeBinary implements io.Serializable interface. +func (ba *BlockedAccounts) DecodeBinary(r *io.BinReader) { + r.ReadArray(ba) +} + +// ToStackItem converts BlockedAccounts to stackitem.Item +func (ba *BlockedAccounts) ToStackItem() stackitem.Item { + result := make([]stackitem.Item, len(*ba)) + for i, account := range *ba { + result[i] = stackitem.NewByteArray(account.BytesLE()) + } + return stackitem.NewArray(result) +} diff --git a/pkg/core/native/blocked_accounts_test.go b/pkg/core/native/blocked_accounts_test.go new file mode 100644 index 000000000..ac2a35a5a --- /dev/null +++ b/pkg/core/native/blocked_accounts_test.go @@ -0,0 +1,53 @@ +package native + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecodeBinary(t *testing.T) { + expected := &BlockedAccounts{ + util.Uint160{1, 2, 3}, + util.Uint160{4, 5, 6}, + } + actual := new(BlockedAccounts) + testserdes.EncodeDecodeBinary(t, expected, actual) + + expected = &BlockedAccounts{} + actual = new(BlockedAccounts) + testserdes.EncodeDecodeBinary(t, expected, actual) +} + +func TestBytesFromBytes(t *testing.T) { + expected := BlockedAccounts{ + util.Uint160{1, 2, 3}, + util.Uint160{4, 5, 6}, + } + actual, err := BlockedAccountsFromBytes(expected.Bytes()) + require.NoError(t, err) + require.Equal(t, expected, actual) + + expected = BlockedAccounts{} + actual, err = BlockedAccountsFromBytes(expected.Bytes()) + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +func TestToStackItem(t *testing.T) { + u1 := util.Uint160{1, 2, 3} + u2 := util.Uint160{4, 5, 6} + expected := BlockedAccounts{u1, u2} + actual := stackitem.NewArray([]stackitem.Item{ + stackitem.NewByteArray(u1.BytesLE()), + stackitem.NewByteArray(u2.BytesLE()), + }) + require.Equal(t, expected.ToStackItem(), actual) + + expected = BlockedAccounts{} + actual = stackitem.NewArray([]stackitem.Item{}) + require.Equal(t, expected.ToStackItem(), actual) +} diff --git a/pkg/core/native/contract.go b/pkg/core/native/contract.go index e9dc4f418..47629e485 100644 --- a/pkg/core/native/contract.go +++ b/pkg/core/native/contract.go @@ -16,6 +16,7 @@ import ( type Contracts struct { NEO *NEO GAS *GAS + Policy *Policy Contracts []interop.Contract // persistScript is vm script which executes "onPersist" method of every native contract. persistScript []byte @@ -41,7 +42,7 @@ func (cs *Contracts) ByID(id uint32) interop.Contract { return nil } -// NewContracts returns new set of native contracts with new GAS and NEO +// NewContracts returns new set of native contracts with new GAS, NEO and Policy // contracts. func NewContracts() *Contracts { cs := new(Contracts) @@ -55,6 +56,10 @@ func NewContracts() *Contracts { cs.Contracts = append(cs.Contracts, gas) cs.NEO = neo cs.Contracts = append(cs.Contracts, neo) + + policy := newPolicy() + cs.Policy = policy + cs.Contracts = append(cs.Contracts, policy) return cs } diff --git a/pkg/core/native/policy.go b/pkg/core/native/policy.go new file mode 100644 index 000000000..6f379b1be --- /dev/null +++ b/pkg/core/native/policy.go @@ -0,0 +1,393 @@ +package native + +import ( + "encoding/binary" + "math/big" + "sort" + + "github.com/nspcc-dev/neo-go/pkg/core/dao" + "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/pkg/errors" +) + +const ( + policySyscallName = "Neo.Native.Policy" + policyContractID = -3 + + defaultMaxBlockSize = 1024 * 256 + defaultMaxTransactionsPerBlock = 512 + defaultFeePerByte = 1000 +) + +var ( + // maxTransactionsPerBlockKey is a key used to store the maximum number of + // transactions allowed in block. + maxTransactionsPerBlockKey = []byte{23} + // feePerByteKey is a key used to store the minimum fee per byte for + // transaction. + feePerByteKey = []byte{10} + // blockedAccountsKey is a key used to store the list of blocked accounts. + blockedAccountsKey = []byte{15} + // maxBlockSizeKey is a key used to store the maximum block size value. + maxBlockSizeKey = []byte{16} +) + +// Policy represents Policy native contract. +type Policy struct { + interop.ContractMD +} + +var _ interop.Contract = (*Policy)(nil) + +// newPolicy returns Policy native contract. +func newPolicy() *Policy { + p := &Policy{ContractMD: *interop.NewContractMD(policySyscallName)} + + p.ContractID = policyContractID + p.Manifest.Features |= smartcontract.HasStorage + + desc := newDescriptor("getMaxTransactionsPerBlock", smartcontract.IntegerType) + md := newMethodAndPrice(p.getMaxTransactionsPerBlock, 1000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, true) + + desc = newDescriptor("getMaxBlockSize", smartcontract.IntegerType) + md = newMethodAndPrice(p.getMaxBlockSize, 1000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, true) + + desc = newDescriptor("getFeePerByte", smartcontract.IntegerType) + md = newMethodAndPrice(p.getFeePerByte, 1000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, true) + + desc = newDescriptor("getBlockedAccounts", smartcontract.ArrayType) + md = newMethodAndPrice(p.getBlockedAccounts, 1000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, true) + + desc = newDescriptor("setMaxBlockSize", smartcontract.BoolType, + manifest.NewParameter("value", smartcontract.IntegerType)) + md = newMethodAndPrice(p.setMaxBlockSize, 3000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, false) + + desc = newDescriptor("setMaxTransactionsPerBlock", smartcontract.BoolType, + manifest.NewParameter("value", smartcontract.IntegerType)) + md = newMethodAndPrice(p.setMaxTransactionsPerBlock, 3000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, false) + + desc = newDescriptor("setFeePerByte", smartcontract.BoolType, + manifest.NewParameter("value", smartcontract.IntegerType)) + md = newMethodAndPrice(p.setFeePerByte, 3000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, false) + + desc = newDescriptor("blockAccount", smartcontract.BoolType, + manifest.NewParameter("account", smartcontract.Hash160Type)) + md = newMethodAndPrice(p.blockAccount, 3000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, false) + + desc = newDescriptor("unblockAccount", smartcontract.BoolType, + manifest.NewParameter("account", smartcontract.Hash160Type)) + md = newMethodAndPrice(p.unblockAccount, 3000000, smartcontract.NoneFlag) + p.AddMethod(md, desc, false) + + desc = newDescriptor("onPersist", smartcontract.BoolType) + md = newMethodAndPrice(getOnPersistWrapper(p.OnPersist), 0, smartcontract.AllowModifyStates) + p.AddMethod(md, desc, false) + return p +} + +// Metadata implements Contract interface. +func (p *Policy) Metadata() *interop.ContractMD { + return &p.ContractMD +} + +// Initialize initializes Policy native contract and implements Contract interface. +func (p *Policy) Initialize(ic *interop.Context) error { + si := &state.StorageItem{ + Value: make([]byte, 4, 8), + } + binary.LittleEndian.PutUint32(si.Value, defaultMaxBlockSize) + err := ic.DAO.PutStorageItem(p.ContractID, maxBlockSizeKey, si) + if err != nil { + return err + } + + binary.LittleEndian.PutUint32(si.Value, defaultMaxTransactionsPerBlock) + err = ic.DAO.PutStorageItem(p.ContractID, maxTransactionsPerBlockKey, si) + if err != nil { + return err + } + + si.Value = si.Value[:8] + binary.LittleEndian.PutUint64(si.Value, defaultFeePerByte) + err = ic.DAO.PutStorageItem(p.ContractID, feePerByteKey, si) + if err != nil { + return err + } + + ba := new(BlockedAccounts) + si.Value = ba.Bytes() + err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, si) + if err != nil { + return err + } + return nil +} + +// OnPersist implements Contract interface. +func (p *Policy) OnPersist(ic *interop.Context) error { + return nil +} + +// getMaxTransactionsPerBlock is Policy contract method and returns the upper +// limit of transactions per block. +func (p *Policy) getMaxTransactionsPerBlock(ic *interop.Context, _ []stackitem.Item) stackitem.Item { + return stackitem.NewBigInteger(big.NewInt(int64(p.GetMaxTransactionsPerBlockInternal(ic.DAO)))) +} + +// GetMaxTransactionsPerBlockInternal returns the upper limit of transactions per +// block. +func (p *Policy) GetMaxTransactionsPerBlockInternal(dao dao.DAO) uint32 { + return p.getUint32WithKey(dao, maxTransactionsPerBlockKey) +} + +// getMaxBlockSize is Policy contract method and returns maximum block size. +func (p *Policy) getMaxBlockSize(ic *interop.Context, _ []stackitem.Item) stackitem.Item { + return stackitem.NewBigInteger(big.NewInt(int64(p.getUint32WithKey(ic.DAO, maxBlockSizeKey)))) +} + +// getFeePerByte is Policy contract method and returns required transaction's fee +// per byte. +func (p *Policy) getFeePerByte(ic *interop.Context, _ []stackitem.Item) stackitem.Item { + return stackitem.NewBigInteger(big.NewInt(p.GetFeePerByteInternal(ic.DAO))) +} + +// GetFeePerByteInternal returns required transaction's fee per byte. +func (p *Policy) GetFeePerByteInternal(dao dao.DAO) int64 { + return p.getInt64WithKey(dao, feePerByteKey) +} + +// getBlockedAccounts is Policy contract method and returns list of blocked +// accounts hashes. +func (p *Policy) getBlockedAccounts(ic *interop.Context, _ []stackitem.Item) stackitem.Item { + ba, err := p.GetBlockedAccountsInternal(ic.DAO) + if err != nil { + panic(err) + } + return ba.ToStackItem() +} + +// GetBlockedAccountsInternal returns list of blocked accounts hashes. +func (p *Policy) GetBlockedAccountsInternal(dao dao.DAO) (BlockedAccounts, error) { + si := dao.GetStorageItem(p.ContractID, blockedAccountsKey) + if si == nil { + return nil, errors.New("BlockedAccounts uninitialized") + } + ba, err := BlockedAccountsFromBytes(si.Value) + if err != nil { + return nil, err + } + return ba, nil +} + +// setMaxTransactionsPerBlock is Policy contract method and sets the upper limit +// of transactions per block. +func (p *Policy) setMaxTransactionsPerBlock(ic *interop.Context, args []stackitem.Item) stackitem.Item { + ok, err := p.checkValidators(ic) + if err != nil { + panic(err) + } + if !ok { + return stackitem.NewBool(false) + } + value := uint32(toBigInt(args[0]).Int64()) + err = p.setUint32WithKey(ic.DAO, maxTransactionsPerBlockKey, value) + if err != nil { + panic(err) + } + return stackitem.NewBool(true) +} + +// setMaxBlockSize is Policy contract method and sets maximum block size. +func (p *Policy) setMaxBlockSize(ic *interop.Context, args []stackitem.Item) stackitem.Item { + ok, err := p.checkValidators(ic) + if err != nil { + panic(err) + } + if !ok { + return stackitem.NewBool(false) + } + value := uint32(toBigInt(args[0]).Int64()) + if payload.MaxSize <= value { + return stackitem.NewBool(false) + } + err = p.setUint32WithKey(ic.DAO, maxBlockSizeKey, value) + if err != nil { + panic(err) + } + return stackitem.NewBool(true) +} + +// setFeePerByte is Policy contract method and sets transaction's fee per byte. +func (p *Policy) setFeePerByte(ic *interop.Context, args []stackitem.Item) stackitem.Item { + ok, err := p.checkValidators(ic) + if err != nil { + panic(err) + } + if !ok { + return stackitem.NewBool(false) + } + value := toBigInt(args[0]).Int64() + err = p.setInt64WithKey(ic.DAO, feePerByteKey, value) + if err != nil { + panic(err) + } + return stackitem.NewBool(true) +} + +// blockAccount is Policy contract method and adds given account hash to the list +// of blocked accounts. +func (p *Policy) blockAccount(ic *interop.Context, args []stackitem.Item) stackitem.Item { + ok, err := p.checkValidators(ic) + if err != nil { + panic(err) + } + if !ok { + return stackitem.NewBool(false) + } + value := toUint160(args[0]) + si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) + if si == nil { + panic("BlockedAccounts uninitialized") + } + ba, err := BlockedAccountsFromBytes(si.Value) + if err != nil { + panic(err) + } + indexToInsert := sort.Search(len(ba), func(i int) bool { + return !ba[i].Less(value) + }) + ba = append(ba, value) + if indexToInsert != len(ba)-1 && ba[indexToInsert].Equals(value) { + return stackitem.NewBool(false) + } + if len(ba) > 1 { + copy(ba[indexToInsert+1:], ba[indexToInsert:]) + ba[indexToInsert] = value + } + err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ + Value: ba.Bytes(), + }) + if err != nil { + panic(err) + } + return stackitem.NewBool(true) +} + +// unblockAccount is Policy contract method and removes given account hash from +// the list of blocked accounts. +func (p *Policy) unblockAccount(ic *interop.Context, args []stackitem.Item) stackitem.Item { + ok, err := p.checkValidators(ic) + if err != nil { + panic(err) + } + if !ok { + return stackitem.NewBool(false) + } + value := toUint160(args[0]) + si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) + if si == nil { + panic("BlockedAccounts uninitialized") + } + ba, err := BlockedAccountsFromBytes(si.Value) + if err != nil { + panic(err) + } + indexToRemove := sort.Search(len(ba), func(i int) bool { + return !ba[i].Less(value) + }) + if indexToRemove == len(ba) || !ba[indexToRemove].Equals(value) { + return stackitem.NewBool(false) + } + ba = append(ba[:indexToRemove], ba[indexToRemove+1:]...) + err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ + Value: ba.Bytes(), + }) + if err != nil { + panic(err) + } + return stackitem.NewBool(true) +} + +func (p *Policy) getUint32WithKey(dao dao.DAO, key []byte) uint32 { + si := dao.GetStorageItem(p.ContractID, key) + if si == nil { + return 0 + } + return binary.LittleEndian.Uint32(si.Value) +} + +func (p *Policy) setUint32WithKey(dao dao.DAO, key []byte, value uint32) error { + si := dao.GetStorageItem(p.ContractID, key) + binary.LittleEndian.PutUint32(si.Value, value) + err := dao.PutStorageItem(p.ContractID, key, si) + if err != nil { + return err + } + return nil +} + +func (p *Policy) getInt64WithKey(dao dao.DAO, key []byte) int64 { + si := dao.GetStorageItem(p.ContractID, key) + if si == nil { + return 0 + } + return int64(binary.LittleEndian.Uint64(si.Value)) +} + +func (p *Policy) setInt64WithKey(dao dao.DAO, key []byte, value int64) error { + si := dao.GetStorageItem(p.ContractID, key) + binary.LittleEndian.PutUint64(si.Value, uint64(value)) + err := dao.PutStorageItem(p.ContractID, key, si) + if err != nil { + return err + } + return nil +} + +func (p *Policy) checkValidators(ic *interop.Context) (bool, error) { + prevBlock, err := ic.Chain.GetBlock(ic.Block.PrevHash) + if err != nil { + return false, err + } + return runtime.CheckHashedWitness(ic, nep5ScriptHash{ + callingScriptHash: p.Hash, + entryScriptHash: p.Hash, + currentScriptHash: p.Hash, + }, prevBlock.NextConsensus) +} + +// CheckPolicy checks whether transaction's script hashes for verifying are +// included into blocked accounts list. +func (p *Policy) CheckPolicy(ic *interop.Context, tx *transaction.Transaction) (bool, error) { + ba, err := p.GetBlockedAccountsInternal(ic.DAO) + if err != nil { + return false, err + } + scriptHashes, err := ic.Chain.GetScriptHashesForVerifying(tx) + if err != nil { + return false, err + } + for _, acc := range ba { + for _, hash := range scriptHashes { + if acc.Equals(hash) { + return false, nil + } + } + } + return true, nil +} diff --git a/pkg/core/native_policy_test.go b/pkg/core/native_policy_test.go new file mode 100644 index 000000000..765f114ec --- /dev/null +++ b/pkg/core/native_policy_test.go @@ -0,0 +1,260 @@ +package core + +import ( + "math/big" + "sort" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/native" + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/emit" + "github.com/stretchr/testify/require" +) + +func TestMaxTransactionsPerBlock(t *testing.T) { + chain := newTestChain(t) + defer chain.Close() + + t.Run("get, internal method", func(t *testing.T) { + n := chain.contracts.Policy.GetMaxTransactionsPerBlockInternal(chain.dao) + require.Equal(t, 512, int(n)) + }) + + t.Run("get, contract method", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "getMaxTransactionsPerBlock") + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.IntegerType, + Value: 512, + }) + require.NoError(t, chain.persist()) + }) + + t.Run("set", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "setMaxTransactionsPerBlock", bigint.ToBytes(big.NewInt(1024))) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + n := chain.contracts.Policy.GetMaxTransactionsPerBlockInternal(chain.dao) + require.Equal(t, 1024, int(n)) + }) +} + +func TestMaxBlockSize(t *testing.T) { + chain := newTestChain(t) + defer chain.Close() + + t.Run("get", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "getMaxBlockSize") + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.IntegerType, + Value: 1024 * 256, + }) + require.NoError(t, chain.persist()) + }) + + t.Run("set", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "setMaxBlockSize", bigint.ToBytes(big.NewInt(102400))) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + res, err = invokeNativePolicyMethod(chain, "getMaxBlockSize") + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.IntegerType, + Value: 102400, + }) + require.NoError(t, chain.persist()) + }) +} + +func TestFeePerByte(t *testing.T) { + chain := newTestChain(t) + defer chain.Close() + + t.Run("get, internal method", func(t *testing.T) { + n := chain.contracts.Policy.GetFeePerByteInternal(chain.dao) + require.Equal(t, 1000, int(n)) + }) + + t.Run("get, contract method", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "getFeePerByte") + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.IntegerType, + Value: 1000, + }) + require.NoError(t, chain.persist()) + }) + + t.Run("set", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "setFeePerByte", bigint.ToBytes(big.NewInt(1024))) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + n := chain.contracts.Policy.GetFeePerByteInternal(chain.dao) + require.Equal(t, 1024, int(n)) + }) +} + +func TestBlockedAccounts(t *testing.T) { + chain := newTestChain(t) + defer chain.Close() + account := util.Uint160{1, 2, 3} + + t.Run("get, internal method", func(t *testing.T) { + accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) + require.NoError(t, err) + require.Equal(t, native.BlockedAccounts{}, accounts) + }) + + t.Run("get, contract method", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "getBlockedAccounts") + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.ArrayType, + Value: []smartcontract.Parameter{}, + }) + require.NoError(t, chain.persist()) + }) + + t.Run("block-unblock account", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "blockAccount", account.BytesBE()) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + + accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) + require.NoError(t, err) + require.Equal(t, native.BlockedAccounts{account}, accounts) + require.NoError(t, chain.persist()) + + res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE()) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + + accounts, err = chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) + require.NoError(t, err) + require.Equal(t, native.BlockedAccounts{}, accounts) + require.NoError(t, chain.persist()) + }) + + t.Run("double-block", func(t *testing.T) { + // block + res, err := invokeNativePolicyMethod(chain, "blockAccount", account.BytesBE()) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + + // double-block should fail + res, err = invokeNativePolicyMethod(chain, "blockAccount", account.BytesBE()) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: false, + }) + require.NoError(t, chain.persist()) + + // unblock + res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE()) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + + // unblock the same account should fail as we don't have it blocked + res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE()) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: false, + }) + require.NoError(t, chain.persist()) + }) + + t.Run("sorted", func(t *testing.T) { + accounts := []util.Uint160{ + {2, 3, 4}, + {4, 5, 6}, + {3, 4, 5}, + {1, 2, 3}, + } + for _, acc := range accounts { + res, err := invokeNativePolicyMethod(chain, "blockAccount", acc.BytesBE()) + require.NoError(t, err) + checkResult(t, res, smartcontract.Parameter{ + Type: smartcontract.BoolType, + Value: true, + }) + require.NoError(t, chain.persist()) + } + + sort.Slice(accounts, func(i, j int) bool { + return accounts[i].Less(accounts[j]) + }) + actual, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) + require.NoError(t, err) + require.Equal(t, native.BlockedAccounts(accounts), actual) + }) +} + +func invokeNativePolicyMethod(chain *Blockchain, method string, args ...interface{}) (*state.AppExecResult, error) { + w := io.NewBufBinWriter() + emit.AppCallWithOperationAndArgs(w.BinWriter, chain.contracts.Policy.Metadata().Hash, method, args...) + if w.Err != nil { + return nil, w.Err + } + script := w.Bytes() + tx := transaction.New(chain.GetConfig().Magic, script, 0) + validUntil := chain.blockHeight + 1 + tx.ValidUntilBlock = validUntil + err := addSender(tx) + if err != nil { + return nil, err + } + addCosigners(tx) + err = signTx(chain, tx) + if err != nil { + return nil, err + } + b := chain.newBlock(tx) + err = chain.AddBlock(b) + if err != nil { + return nil, err + } + + res, err := chain.GetAppExecResult(tx.Hash()) + if err != nil { + return nil, err + } + return res, nil +} + +func checkResult(t *testing.T, result *state.AppExecResult, expected smartcontract.Parameter) { + require.Equal(t, "HALT", result.VMState) + require.Equal(t, 1, len(result.Stack)) + require.Equal(t, expected.Type, result.Stack[0].Type) + require.EqualValues(t, expected.Value, result.Stack[0].Value) +} diff --git a/pkg/network/compress.go b/pkg/network/compress.go index ddde5e9d7..bed34690d 100644 --- a/pkg/network/compress.go +++ b/pkg/network/compress.go @@ -1,6 +1,7 @@ package network import ( + "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/pierrec/lz4" ) @@ -17,8 +18,8 @@ func compress(source []byte) ([]byte, error) { // decompress decompresses bytes using lz4. func decompress(source []byte) ([]byte, error) { maxSize := len(source) * 255 - if maxSize > PayloadMaxSize { - maxSize = PayloadMaxSize + if maxSize > payload.MaxSize { + maxSize = payload.MaxSize } dest := make([]byte, maxSize) size, err := lz4.UncompressBlock(source, dest) diff --git a/pkg/network/message.go b/pkg/network/message.go index a7ec61c9e..bb0678d79 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -14,12 +14,8 @@ import ( //go:generate stringer -type=CommandType -const ( - // PayloadMaxSize is maximum payload size in decompressed form. - PayloadMaxSize = 0x02000000 - // CompressionMinSize is the lower bound to apply compression. - CompressionMinSize = 1024 -) +// CompressionMinSize is the lower bound to apply compression. +const CompressionMinSize = 1024 // Message is the complete message send between nodes. type Message struct { @@ -114,7 +110,7 @@ func (m *Message) Decode(br *io.BinReader) error { } return nil } - if l > PayloadMaxSize { + if l > payload.MaxSize { return errors.New("invalid payload size") } m.compressedPayload = make([]byte, l) diff --git a/pkg/network/payload/payload.go b/pkg/network/payload/payload.go index a761de2e2..8fc3ef211 100644 --- a/pkg/network/payload/payload.go +++ b/pkg/network/payload/payload.go @@ -2,6 +2,9 @@ package payload import "github.com/nspcc-dev/neo-go/pkg/io" +// MaxSize is maximum payload size in decompressed form. +const MaxSize = 0x02000000 + // Payload is anything that can be binary encoded/decoded. type Payload interface { io.Serializable