From 0da01fde7fd428dde0eaba828ef8aca7eb67ee66 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 21 Oct 2020 15:51:59 +0300 Subject: [PATCH] core: refactor blocked accounts logic --- pkg/core/native/blocked_accounts.go | 53 --------- pkg/core/native/blocked_accounts_test.go | 64 ----------- pkg/core/native/policy.go | 137 ++++++++--------------- pkg/core/native_policy_test.go | 47 ++------ pkg/rpc/client/policy.go | 40 +++---- pkg/rpc/client/rpc_test.go | 9 +- 6 files changed, 77 insertions(+), 273 deletions(-) delete mode 100644 pkg/core/native/blocked_accounts.go delete mode 100644 pkg/core/native/blocked_accounts_test.go diff --git a/pkg/core/native/blocked_accounts.go b/pkg/core/native/blocked_accounts.go deleted file mode 100644 index fb27ce34f..000000000 --- a/pkg/core/native/blocked_accounts.go +++ /dev/null @@ -1,53 +0,0 @@ -package native - -import ( - "github.com/nspcc-dev/neo-go/pkg/io" - "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" -) - -// BlockedAccounts represents a slice of blocked accounts hashes. -type BlockedAccounts []util.Uint160 - -// Bytes returns serialized BlockedAccounts. -func (ba *BlockedAccounts) Bytes() []byte { - w := io.NewBufBinWriter() - ba.EncodeBinary(w.BinWriter) - if w.Err != nil { - panic(w.Err) - } - return w.Bytes() -} - -// EncodeBinary implements io.Serializable interface. -func (ba *BlockedAccounts) EncodeBinary(w *io.BinWriter) { - w.WriteArray(*ba) -} - -// BlockedAccountsFromBytes converts serialized BlockedAccounts to structure. -func BlockedAccountsFromBytes(b []byte) (BlockedAccounts, error) { - ba := new(BlockedAccounts) - if len(b) == 0 { - return *ba, nil - } - r := io.NewBinReaderFromBuf(b) - ba.DecodeBinary(r) - if r.Err != nil { - return nil, r.Err - } - return *ba, nil -} - -// DecodeBinary implements io.Serializable interface. -func (ba *BlockedAccounts) DecodeBinary(r *io.BinReader) { - r.ReadArray(ba) -} - -// ToStackItem converts BlockedAccounts to stackitem.Item -func (ba *BlockedAccounts) ToStackItem() stackitem.Item { - result := make([]stackitem.Item, len(*ba)) - for i, account := range *ba { - result[i] = stackitem.NewByteArray(account.BytesLE()) - } - return stackitem.NewArray(result) -} diff --git a/pkg/core/native/blocked_accounts_test.go b/pkg/core/native/blocked_accounts_test.go deleted file mode 100644 index 76709395e..000000000 --- a/pkg/core/native/blocked_accounts_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package native - -import ( - "encoding/json" - "testing" - - "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" - "github.com/nspcc-dev/neo-go/pkg/smartcontract" - "github.com/nspcc-dev/neo-go/pkg/util" - "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" - "github.com/stretchr/testify/require" -) - -func TestEncodeDecodeBinary(t *testing.T) { - expected := &BlockedAccounts{ - util.Uint160{1, 2, 3}, - util.Uint160{4, 5, 6}, - } - actual := new(BlockedAccounts) - testserdes.EncodeDecodeBinary(t, expected, actual) - - expected = &BlockedAccounts{} - actual = new(BlockedAccounts) - testserdes.EncodeDecodeBinary(t, expected, actual) -} - -func TestBytesFromBytes(t *testing.T) { - expected := BlockedAccounts{ - util.Uint160{1, 2, 3}, - util.Uint160{4, 5, 6}, - } - actual, err := BlockedAccountsFromBytes(expected.Bytes()) - require.NoError(t, err) - require.Equal(t, expected, actual) - - expected = BlockedAccounts{} - actual, err = BlockedAccountsFromBytes(expected.Bytes()) - require.NoError(t, err) - require.Equal(t, expected, actual) -} - -func TestToStackItem(t *testing.T) { - u1 := util.Uint160{1, 2, 3} - u2 := util.Uint160{4, 5, 6} - expected := BlockedAccounts{u1, u2} - actual := stackitem.NewArray([]stackitem.Item{ - stackitem.NewByteArray(u1.BytesLE()), - stackitem.NewByteArray(u2.BytesLE()), - }) - require.Equal(t, expected.ToStackItem(), actual) - - expected = BlockedAccounts{} - actual = stackitem.NewArray([]stackitem.Item{}) - require.Equal(t, expected.ToStackItem(), actual) -} - -func TestMarshallJSON(t *testing.T) { - ba := &BlockedAccounts{} - p := smartcontract.ParameterFromStackItem(ba.ToStackItem(), make(map[stackitem.Item]bool)) - actual, err := json.Marshal(p) - require.NoError(t, err) - expected := `{"type":"Array","value":[]}` - require.Equal(t, expected, string(actual)) -} diff --git a/pkg/core/native/policy.go b/pkg/core/native/policy.go index d0e8f7533..22a53cfbb 100644 --- a/pkg/core/native/policy.go +++ b/pkg/core/native/policy.go @@ -2,7 +2,6 @@ package native import ( "encoding/binary" - "errors" "fmt" "math/big" "sort" @@ -34,6 +33,9 @@ const ( minBlockSystemFee = 4007600 // maxFeePerByte is the maximum allowed fee per byte value. maxFeePerByte = 100_000_000 + + // blockedAccountPrefix is a prefix used to store blocked account. + blockedAccountPrefix = 15 ) var ( @@ -43,8 +45,6 @@ var ( // feePerByteKey is a key used to store the minimum fee per byte for // transaction. feePerByteKey = []byte{10} - // blockedAccountsKey is a key used to store the list of blocked accounts. - blockedAccountsKey = []byte{15} // maxBlockSizeKey is a key used to store the maximum block size value. maxBlockSizeKey = []byte{12} // maxBlockSystemFeeKey is a key used to store the maximum block system fee value. @@ -64,7 +64,7 @@ type Policy struct { feePerByte int64 maxBlockSystemFee int64 maxVerificationGas int64 - blockedAccounts BlockedAccounts + blockedAccounts []util.Uint160 } var _ interop.Contract = (*Policy)(nil) @@ -88,8 +88,9 @@ func newPolicy() *Policy { md = newMethodAndPrice(p.getFeePerByte, 1000000, smartcontract.AllowStates) p.AddMethod(md, desc, true) - desc = newDescriptor("getBlockedAccounts", smartcontract.ArrayType) - md = newMethodAndPrice(p.getBlockedAccounts, 1000000, smartcontract.AllowStates) + desc = newDescriptor("isBlocked", smartcontract.BoolType, + manifest.NewParameter("account", smartcontract.Hash160Type)) + md = newMethodAndPrice(p.isBlocked, 1000000, smartcontract.AllowStates) p.AddMethod(md, desc, true) desc = newDescriptor("getMaxBlockSystemFee", smartcontract.IntegerType) @@ -175,13 +176,6 @@ func (p *Policy) Initialize(ic *interop.Context) error { return err } - ba := new(BlockedAccounts) - si.Value = ba.Bytes() - err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, si) - if err != nil { - return err - } - p.isValid = true p.maxTransactionsPerBlock = defaultMaxTransactionsPerBlock p.maxBlockSize = defaultMaxBlockSize @@ -220,15 +214,21 @@ func (p *Policy) OnPersistEnd(dao dao.DAO) error { p.maxVerificationGas = defaultMaxVerificationGas - si := dao.GetStorageItem(p.ContractID, blockedAccountsKey) - if si == nil { - return errors.New("BlockedAccounts uninitialized") - } - ba, err := BlockedAccountsFromBytes(si.Value) + p.blockedAccounts = make([]util.Uint160, 0) + siMap, err := dao.GetStorageItemsWithPrefix(p.ContractID, []byte{blockedAccountPrefix}) if err != nil { - return fmt.Errorf("failed to decode BlockedAccounts from bytes: %w", err) + return fmt.Errorf("failed to get blocked accounts from storage: %w", err) } - p.blockedAccounts = ba + for key := range siMap { + hash, err := util.Uint160DecodeBytesBE([]byte(key)) + if err != nil { + return fmt.Errorf("failed to decode blocked account hash: %w", err) + } + p.blockedAccounts = append(p.blockedAccounts, hash) + } + sort.Slice(p.blockedAccounts, func(i, j int) bool { + return p.blockedAccounts[i].Less(p.blockedAccounts[j]) + }) p.isValid = true return nil @@ -306,32 +306,28 @@ func (p *Policy) GetMaxBlockSystemFeeInternal(dao dao.DAO) int64 { return p.getInt64WithKey(dao, maxBlockSystemFeeKey) } -// getBlockedAccounts is Policy contract method and returns list of blocked -// accounts hashes. -func (p *Policy) getBlockedAccounts(ic *interop.Context, _ []stackitem.Item) stackitem.Item { - ba, err := p.GetBlockedAccountsInternal(ic.DAO) - if err != nil { - panic(err) - } - return ba.ToStackItem() +// isBlocked is Policy contract method and checks whether provided account is blocked. +func (p *Policy) isBlocked(ic *interop.Context, args []stackitem.Item) stackitem.Item { + hash := toUint160(args[0]) + return stackitem.NewBool(p.IsBlockedInternal(ic.DAO, hash)) } -// GetBlockedAccountsInternal returns list of blocked accounts hashes. -func (p *Policy) GetBlockedAccountsInternal(dao dao.DAO) (BlockedAccounts, error) { +// IsBlockedInternal checks whether provided account is blocked +func (p *Policy) IsBlockedInternal(dao dao.DAO, hash util.Uint160) bool { p.lock.RLock() defer p.lock.RUnlock() if p.isValid { - return p.blockedAccounts, nil + length := len(p.blockedAccounts) + i := sort.Search(length, func(i int) bool { + return !p.blockedAccounts[i].Less(hash) + }) + if length != 0 && i != length && p.blockedAccounts[i].Equals(hash) { + return true + } + return false } - si := dao.GetStorageItem(p.ContractID, blockedAccountsKey) - if si == nil { - return nil, errors.New("BlockedAccounts uninitialized") - } - ba, err := BlockedAccountsFromBytes(si.Value) - if err != nil { - return nil, err - } - return ba, nil + key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...) + return dao.GetStorageItem(p.ContractID, key) != nil } // setMaxTransactionsPerBlock is Policy contract method and sets the upper limit @@ -437,30 +433,15 @@ func (p *Policy) blockAccount(ic *interop.Context, args []stackitem.Item) stacki if !ok { return stackitem.NewBool(false) } - value := toUint160(args[0]) - si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) - if si == nil { - panic("BlockedAccounts uninitialized") - } - ba, err := BlockedAccountsFromBytes(si.Value) - if err != nil { - panic(err) - } - indexToInsert := sort.Search(len(ba), func(i int) bool { - return !ba[i].Less(value) - }) - ba = append(ba, value) - if indexToInsert != len(ba)-1 && ba[indexToInsert].Equals(value) { + hash := toUint160(args[0]) + if p.IsBlockedInternal(ic.DAO, hash) { return stackitem.NewBool(false) } - if len(ba) > 1 { - copy(ba[indexToInsert+1:], ba[indexToInsert:]) - ba[indexToInsert] = value - } + key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...) p.lock.Lock() defer p.lock.Unlock() - err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ - Value: ba.Bytes(), + err = ic.DAO.PutStorageItem(p.ContractID, key, &state.StorageItem{ + Value: []byte{0x01}, }) if err != nil { panic(err) @@ -479,27 +460,14 @@ func (p *Policy) unblockAccount(ic *interop.Context, args []stackitem.Item) stac if !ok { return stackitem.NewBool(false) } - value := toUint160(args[0]) - si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) - if si == nil { - panic("BlockedAccounts uninitialized") - } - ba, err := BlockedAccountsFromBytes(si.Value) - if err != nil { - panic(err) - } - indexToRemove := sort.Search(len(ba), func(i int) bool { - return !ba[i].Less(value) - }) - if indexToRemove == len(ba) || !ba[indexToRemove].Equals(value) { + hash := toUint160(args[0]) + if !p.IsBlockedInternal(ic.DAO, hash) { return stackitem.NewBool(false) } - ba = append(ba[:indexToRemove], ba[indexToRemove+1:]...) + key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...) p.lock.Lock() defer p.lock.Unlock() - err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ - Value: ba.Bytes(), - }) + err = ic.DAO.DeleteStorageItem(p.ContractID, key) if err != nil { panic(err) } @@ -555,18 +523,9 @@ func (p *Policy) checkValidators(ic *interop.Context) (bool, error) { // like not being signed by blocked account or not exceeding block-level system // fee limit. func (p *Policy) CheckPolicy(d dao.DAO, tx *transaction.Transaction) error { - ba, err := p.GetBlockedAccountsInternal(d) - if err != nil { - return fmt.Errorf("unable to get blocked accounts list: %w", err) - } - if len(ba) > 0 { - for _, signer := range tx.Signers { - i := sort.Search(len(ba), func(i int) bool { - return !ba[i].Less(signer.Account) - }) - if i != len(ba) && ba[i].Equals(signer.Account) { - return fmt.Errorf("account %s is blocked", signer.Account.StringLE()) - } + for _, signer := range tx.Signers { + if p.IsBlockedInternal(d, signer.Account) { + return fmt.Errorf("account %s is blocked", signer.Account.StringLE()) } } maxBlockSystemFee := p.GetMaxBlockSystemFeeInternal(d) diff --git a/pkg/core/native_policy_test.go b/pkg/core/native_policy_test.go index 9a5c9f777..4312efdb4 100644 --- a/pkg/core/native_policy_test.go +++ b/pkg/core/native_policy_test.go @@ -2,7 +2,6 @@ package core import ( "math/big" - "sort" "testing" "github.com/nspcc-dev/neo-go/pkg/core/block" @@ -10,6 +9,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -162,16 +162,15 @@ func TestBlockedAccounts(t *testing.T) { defer chain.Close() account := util.Uint160{1, 2, 3} - t.Run("get, internal method", func(t *testing.T) { - accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) - require.NoError(t, err) - require.Equal(t, native.BlockedAccounts{}, accounts) + t.Run("isBlocked, internal method", func(t *testing.T) { + isBlocked := chain.contracts.Policy.IsBlockedInternal(chain.dao, random.Uint160()) + require.Equal(t, false, isBlocked) }) - t.Run("get, contract method", func(t *testing.T) { - res, err := invokeNativePolicyMethod(chain, "getBlockedAccounts") + t.Run("isBlocked, contract method", func(t *testing.T) { + res, err := invokeNativePolicyMethod(chain, "isBlocked", random.Uint160()) require.NoError(t, err) - checkResult(t, res, stackitem.NewArray([]stackitem.Item{})) + checkResult(t, res, stackitem.NewBool(false)) require.NoError(t, chain.persist()) }) @@ -180,18 +179,16 @@ func TestBlockedAccounts(t *testing.T) { require.NoError(t, err) checkResult(t, res, stackitem.NewBool(true)) - accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) - require.NoError(t, err) - require.Equal(t, native.BlockedAccounts{account}, accounts) + isBlocked := chain.contracts.Policy.IsBlockedInternal(chain.dao, account) + require.Equal(t, isBlocked, true) require.NoError(t, chain.persist()) res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE()) require.NoError(t, err) checkResult(t, res, stackitem.NewBool(true)) - accounts, err = chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) - require.NoError(t, err) - require.Equal(t, native.BlockedAccounts{}, accounts) + isBlocked = chain.contracts.Policy.IsBlockedInternal(chain.dao, account) + require.Equal(t, false, isBlocked) require.NoError(t, chain.persist()) }) @@ -220,28 +217,6 @@ func TestBlockedAccounts(t *testing.T) { checkResult(t, res, stackitem.NewBool(false)) require.NoError(t, chain.persist()) }) - - t.Run("sorted", func(t *testing.T) { - accounts := []util.Uint160{ - {2, 3, 4}, - {4, 5, 6}, - {3, 4, 5}, - {1, 2, 3}, - } - for _, acc := range accounts { - res, err := invokeNativePolicyMethod(chain, "blockAccount", acc.BytesBE()) - require.NoError(t, err) - checkResult(t, res, stackitem.NewBool(true)) - require.NoError(t, chain.persist()) - } - - sort.Slice(accounts, func(i, j int) bool { - return accounts[i].Less(accounts[j]) - }) - actual, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) - require.NoError(t, err) - require.Equal(t, native.BlockedAccounts(accounts), actual) - }) } func invokeNativePolicyMethod(chain *Blockchain, method string, args ...interface{}) (*state.AppExecResult, error) { diff --git a/pkg/rpc/client/policy.go b/pkg/rpc/client/policy.go index 7e6aa59bf..efffcd492 100644 --- a/pkg/rpc/client/policy.go +++ b/pkg/rpc/client/policy.go @@ -3,7 +3,6 @@ package client import ( "fmt" - "github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -41,39 +40,28 @@ func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) { return topIntFromStack(result.Stack) } -// GetBlockedAccounts invokes `getBlockedAccounts` method on a native Policy contract. -func (c *Client) GetBlockedAccounts() (native.BlockedAccounts, error) { - result, err := c.InvokeFunction(PolicyContractHash, "getBlockedAccounts", []smartcontract.Parameter{}, nil) +// IsBlocked invokes `isBlocked` method on native Policy contract. +func (c *Client) IsBlocked(hash util.Uint160) (bool, error) { + result, err := c.InvokeFunction(PolicyContractHash, "isBlocked", []smartcontract.Parameter{{ + Type: smartcontract.Hash160Type, + Value: hash, + }}, nil) if err != nil { - return nil, err + return false, err } err = getInvocationError(result) if err != nil { - return nil, fmt.Errorf("failed to get blocked accounts: %w", err) + return false, fmt.Errorf("failed to check if account is blocked: %w", err) } - return topBlockedAccountsFromStack(result.Stack) + return topBoolFromStack(result.Stack) } -func topBlockedAccountsFromStack(st []stackitem.Item) (native.BlockedAccounts, error) { +// topBoolFromStack returns the top boolean value from stack +func topBoolFromStack(st []stackitem.Item) (bool, error) { index := len(st) - 1 // top stack element is last in the array - var ( - ba native.BlockedAccounts - err error - ) - items, ok := st[index].Value().([]stackitem.Item) + result, ok := st[index].Value().(bool) if !ok { - return nil, fmt.Errorf("invalid stack item type: %s", st[index].Type()) + return false, fmt.Errorf("invalid stack item type: %s", st[index].Type()) } - ba = make(native.BlockedAccounts, len(items)) - for i, account := range items { - val, ok := account.Value().([]byte) - if !ok { - return nil, fmt.Errorf("invalid array element: %s", account.Type()) - } - ba[i], err = util.Uint160DecodeBytesLE(val) - if err != nil { - return nil, err - } - } - return ba, nil + return result, nil } diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 1be95bba9..d1913a12f 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -16,7 +16,6 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" - "github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -378,15 +377,15 @@ var rpcClientTestCases = map[string][]rpcClientTestCase{ }, }, }, - "getBlockedAccounts": { + "isBlocked": { { name: "positive", invoke: func(c *Client) (interface{}, error) { - return c.GetBlockedAccounts() + return c.IsBlocked(util.Uint160{1, 2, 3}) }, - serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"state":"HALT","gasconsumed":"2007390","script":"EMAMEmdldEJsb2NrZWRBY2NvdW50cwwUmmGkbuyXuJMG186B8VtGIJHQCTJBYn1bUg==","stack":[{"type":"Array","value":[]}],"tx":null}}`, + serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"state":"HALT","gasconsumed":"2007390","script":"EMAMEmdldEJsb2NrZWRBY2NvdW50cwwUmmGkbuyXuJMG186B8VtGIJHQCTJBYn1bUg==","stack":[{"type":"Boolean","value":false}],"tx":null}}`, result: func(c *Client) interface{} { - return native.BlockedAccounts{} + return false }, }, },