core: refactor blocked accounts logic

This commit is contained in:
Anna Shaleva 2020-10-21 15:51:59 +03:00
parent 6685f8eba9
commit 0da01fde7f
6 changed files with 77 additions and 273 deletions

View file

@ -1,53 +0,0 @@
package native
import (
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
)
// BlockedAccounts represents a slice of blocked accounts hashes.
type BlockedAccounts []util.Uint160
// Bytes returns serialized BlockedAccounts.
func (ba *BlockedAccounts) Bytes() []byte {
w := io.NewBufBinWriter()
ba.EncodeBinary(w.BinWriter)
if w.Err != nil {
panic(w.Err)
}
return w.Bytes()
}
// EncodeBinary implements io.Serializable interface.
func (ba *BlockedAccounts) EncodeBinary(w *io.BinWriter) {
w.WriteArray(*ba)
}
// BlockedAccountsFromBytes converts serialized BlockedAccounts to structure.
func BlockedAccountsFromBytes(b []byte) (BlockedAccounts, error) {
ba := new(BlockedAccounts)
if len(b) == 0 {
return *ba, nil
}
r := io.NewBinReaderFromBuf(b)
ba.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
}
return *ba, nil
}
// DecodeBinary implements io.Serializable interface.
func (ba *BlockedAccounts) DecodeBinary(r *io.BinReader) {
r.ReadArray(ba)
}
// ToStackItem converts BlockedAccounts to stackitem.Item
func (ba *BlockedAccounts) ToStackItem() stackitem.Item {
result := make([]stackitem.Item, len(*ba))
for i, account := range *ba {
result[i] = stackitem.NewByteArray(account.BytesLE())
}
return stackitem.NewArray(result)
}

View file

@ -1,64 +0,0 @@
package native
import (
"encoding/json"
"testing"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require"
)
func TestEncodeDecodeBinary(t *testing.T) {
expected := &BlockedAccounts{
util.Uint160{1, 2, 3},
util.Uint160{4, 5, 6},
}
actual := new(BlockedAccounts)
testserdes.EncodeDecodeBinary(t, expected, actual)
expected = &BlockedAccounts{}
actual = new(BlockedAccounts)
testserdes.EncodeDecodeBinary(t, expected, actual)
}
func TestBytesFromBytes(t *testing.T) {
expected := BlockedAccounts{
util.Uint160{1, 2, 3},
util.Uint160{4, 5, 6},
}
actual, err := BlockedAccountsFromBytes(expected.Bytes())
require.NoError(t, err)
require.Equal(t, expected, actual)
expected = BlockedAccounts{}
actual, err = BlockedAccountsFromBytes(expected.Bytes())
require.NoError(t, err)
require.Equal(t, expected, actual)
}
func TestToStackItem(t *testing.T) {
u1 := util.Uint160{1, 2, 3}
u2 := util.Uint160{4, 5, 6}
expected := BlockedAccounts{u1, u2}
actual := stackitem.NewArray([]stackitem.Item{
stackitem.NewByteArray(u1.BytesLE()),
stackitem.NewByteArray(u2.BytesLE()),
})
require.Equal(t, expected.ToStackItem(), actual)
expected = BlockedAccounts{}
actual = stackitem.NewArray([]stackitem.Item{})
require.Equal(t, expected.ToStackItem(), actual)
}
func TestMarshallJSON(t *testing.T) {
ba := &BlockedAccounts{}
p := smartcontract.ParameterFromStackItem(ba.ToStackItem(), make(map[stackitem.Item]bool))
actual, err := json.Marshal(p)
require.NoError(t, err)
expected := `{"type":"Array","value":[]}`
require.Equal(t, expected, string(actual))
}

View file

@ -2,7 +2,6 @@ package native
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"sort" "sort"
@ -34,6 +33,9 @@ const (
minBlockSystemFee = 4007600 minBlockSystemFee = 4007600
// maxFeePerByte is the maximum allowed fee per byte value. // maxFeePerByte is the maximum allowed fee per byte value.
maxFeePerByte = 100_000_000 maxFeePerByte = 100_000_000
// blockedAccountPrefix is a prefix used to store blocked account.
blockedAccountPrefix = 15
) )
var ( var (
@ -43,8 +45,6 @@ var (
// feePerByteKey is a key used to store the minimum fee per byte for // feePerByteKey is a key used to store the minimum fee per byte for
// transaction. // transaction.
feePerByteKey = []byte{10} feePerByteKey = []byte{10}
// blockedAccountsKey is a key used to store the list of blocked accounts.
blockedAccountsKey = []byte{15}
// maxBlockSizeKey is a key used to store the maximum block size value. // maxBlockSizeKey is a key used to store the maximum block size value.
maxBlockSizeKey = []byte{12} maxBlockSizeKey = []byte{12}
// maxBlockSystemFeeKey is a key used to store the maximum block system fee value. // maxBlockSystemFeeKey is a key used to store the maximum block system fee value.
@ -64,7 +64,7 @@ type Policy struct {
feePerByte int64 feePerByte int64
maxBlockSystemFee int64 maxBlockSystemFee int64
maxVerificationGas int64 maxVerificationGas int64
blockedAccounts BlockedAccounts blockedAccounts []util.Uint160
} }
var _ interop.Contract = (*Policy)(nil) var _ interop.Contract = (*Policy)(nil)
@ -88,8 +88,9 @@ func newPolicy() *Policy {
md = newMethodAndPrice(p.getFeePerByte, 1000000, smartcontract.AllowStates) md = newMethodAndPrice(p.getFeePerByte, 1000000, smartcontract.AllowStates)
p.AddMethod(md, desc, true) p.AddMethod(md, desc, true)
desc = newDescriptor("getBlockedAccounts", smartcontract.ArrayType) desc = newDescriptor("isBlocked", smartcontract.BoolType,
md = newMethodAndPrice(p.getBlockedAccounts, 1000000, smartcontract.AllowStates) manifest.NewParameter("account", smartcontract.Hash160Type))
md = newMethodAndPrice(p.isBlocked, 1000000, smartcontract.AllowStates)
p.AddMethod(md, desc, true) p.AddMethod(md, desc, true)
desc = newDescriptor("getMaxBlockSystemFee", smartcontract.IntegerType) desc = newDescriptor("getMaxBlockSystemFee", smartcontract.IntegerType)
@ -175,13 +176,6 @@ func (p *Policy) Initialize(ic *interop.Context) error {
return err return err
} }
ba := new(BlockedAccounts)
si.Value = ba.Bytes()
err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, si)
if err != nil {
return err
}
p.isValid = true p.isValid = true
p.maxTransactionsPerBlock = defaultMaxTransactionsPerBlock p.maxTransactionsPerBlock = defaultMaxTransactionsPerBlock
p.maxBlockSize = defaultMaxBlockSize p.maxBlockSize = defaultMaxBlockSize
@ -220,15 +214,21 @@ func (p *Policy) OnPersistEnd(dao dao.DAO) error {
p.maxVerificationGas = defaultMaxVerificationGas p.maxVerificationGas = defaultMaxVerificationGas
si := dao.GetStorageItem(p.ContractID, blockedAccountsKey) p.blockedAccounts = make([]util.Uint160, 0)
if si == nil { siMap, err := dao.GetStorageItemsWithPrefix(p.ContractID, []byte{blockedAccountPrefix})
return errors.New("BlockedAccounts uninitialized")
}
ba, err := BlockedAccountsFromBytes(si.Value)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode BlockedAccounts from bytes: %w", err) return fmt.Errorf("failed to get blocked accounts from storage: %w", err)
} }
p.blockedAccounts = ba for key := range siMap {
hash, err := util.Uint160DecodeBytesBE([]byte(key))
if err != nil {
return fmt.Errorf("failed to decode blocked account hash: %w", err)
}
p.blockedAccounts = append(p.blockedAccounts, hash)
}
sort.Slice(p.blockedAccounts, func(i, j int) bool {
return p.blockedAccounts[i].Less(p.blockedAccounts[j])
})
p.isValid = true p.isValid = true
return nil return nil
@ -306,32 +306,28 @@ func (p *Policy) GetMaxBlockSystemFeeInternal(dao dao.DAO) int64 {
return p.getInt64WithKey(dao, maxBlockSystemFeeKey) return p.getInt64WithKey(dao, maxBlockSystemFeeKey)
} }
// getBlockedAccounts is Policy contract method and returns list of blocked // isBlocked is Policy contract method and checks whether provided account is blocked.
// accounts hashes. func (p *Policy) isBlocked(ic *interop.Context, args []stackitem.Item) stackitem.Item {
func (p *Policy) getBlockedAccounts(ic *interop.Context, _ []stackitem.Item) stackitem.Item { hash := toUint160(args[0])
ba, err := p.GetBlockedAccountsInternal(ic.DAO) return stackitem.NewBool(p.IsBlockedInternal(ic.DAO, hash))
if err != nil {
panic(err)
}
return ba.ToStackItem()
} }
// GetBlockedAccountsInternal returns list of blocked accounts hashes. // IsBlockedInternal checks whether provided account is blocked
func (p *Policy) GetBlockedAccountsInternal(dao dao.DAO) (BlockedAccounts, error) { func (p *Policy) IsBlockedInternal(dao dao.DAO, hash util.Uint160) bool {
p.lock.RLock() p.lock.RLock()
defer p.lock.RUnlock() defer p.lock.RUnlock()
if p.isValid { if p.isValid {
return p.blockedAccounts, nil length := len(p.blockedAccounts)
i := sort.Search(length, func(i int) bool {
return !p.blockedAccounts[i].Less(hash)
})
if length != 0 && i != length && p.blockedAccounts[i].Equals(hash) {
return true
}
return false
} }
si := dao.GetStorageItem(p.ContractID, blockedAccountsKey) key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...)
if si == nil { return dao.GetStorageItem(p.ContractID, key) != nil
return nil, errors.New("BlockedAccounts uninitialized")
}
ba, err := BlockedAccountsFromBytes(si.Value)
if err != nil {
return nil, err
}
return ba, nil
} }
// setMaxTransactionsPerBlock is Policy contract method and sets the upper limit // setMaxTransactionsPerBlock is Policy contract method and sets the upper limit
@ -437,30 +433,15 @@ func (p *Policy) blockAccount(ic *interop.Context, args []stackitem.Item) stacki
if !ok { if !ok {
return stackitem.NewBool(false) return stackitem.NewBool(false)
} }
value := toUint160(args[0]) hash := toUint160(args[0])
si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) if p.IsBlockedInternal(ic.DAO, hash) {
if si == nil {
panic("BlockedAccounts uninitialized")
}
ba, err := BlockedAccountsFromBytes(si.Value)
if err != nil {
panic(err)
}
indexToInsert := sort.Search(len(ba), func(i int) bool {
return !ba[i].Less(value)
})
ba = append(ba, value)
if indexToInsert != len(ba)-1 && ba[indexToInsert].Equals(value) {
return stackitem.NewBool(false) return stackitem.NewBool(false)
} }
if len(ba) > 1 { key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...)
copy(ba[indexToInsert+1:], ba[indexToInsert:])
ba[indexToInsert] = value
}
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ err = ic.DAO.PutStorageItem(p.ContractID, key, &state.StorageItem{
Value: ba.Bytes(), Value: []byte{0x01},
}) })
if err != nil { if err != nil {
panic(err) panic(err)
@ -479,27 +460,14 @@ func (p *Policy) unblockAccount(ic *interop.Context, args []stackitem.Item) stac
if !ok { if !ok {
return stackitem.NewBool(false) return stackitem.NewBool(false)
} }
value := toUint160(args[0]) hash := toUint160(args[0])
si := ic.DAO.GetStorageItem(p.ContractID, blockedAccountsKey) if !p.IsBlockedInternal(ic.DAO, hash) {
if si == nil {
panic("BlockedAccounts uninitialized")
}
ba, err := BlockedAccountsFromBytes(si.Value)
if err != nil {
panic(err)
}
indexToRemove := sort.Search(len(ba), func(i int) bool {
return !ba[i].Less(value)
})
if indexToRemove == len(ba) || !ba[indexToRemove].Equals(value) {
return stackitem.NewBool(false) return stackitem.NewBool(false)
} }
ba = append(ba[:indexToRemove], ba[indexToRemove+1:]...) key := append([]byte{blockedAccountPrefix}, hash.BytesBE()...)
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
err = ic.DAO.PutStorageItem(p.ContractID, blockedAccountsKey, &state.StorageItem{ err = ic.DAO.DeleteStorageItem(p.ContractID, key)
Value: ba.Bytes(),
})
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -555,18 +523,9 @@ func (p *Policy) checkValidators(ic *interop.Context) (bool, error) {
// like not being signed by blocked account or not exceeding block-level system // like not being signed by blocked account or not exceeding block-level system
// fee limit. // fee limit.
func (p *Policy) CheckPolicy(d dao.DAO, tx *transaction.Transaction) error { func (p *Policy) CheckPolicy(d dao.DAO, tx *transaction.Transaction) error {
ba, err := p.GetBlockedAccountsInternal(d) for _, signer := range tx.Signers {
if err != nil { if p.IsBlockedInternal(d, signer.Account) {
return fmt.Errorf("unable to get blocked accounts list: %w", err) return fmt.Errorf("account %s is blocked", signer.Account.StringLE())
}
if len(ba) > 0 {
for _, signer := range tx.Signers {
i := sort.Search(len(ba), func(i int) bool {
return !ba[i].Less(signer.Account)
})
if i != len(ba) && ba[i].Equals(signer.Account) {
return fmt.Errorf("account %s is blocked", signer.Account.StringLE())
}
} }
} }
maxBlockSystemFee := p.GetMaxBlockSystemFeeInternal(d) maxBlockSystemFee := p.GetMaxBlockSystemFeeInternal(d)

View file

@ -2,7 +2,6 @@ package core
import ( import (
"math/big" "math/big"
"sort"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
@ -10,6 +9,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -162,16 +162,15 @@ func TestBlockedAccounts(t *testing.T) {
defer chain.Close() defer chain.Close()
account := util.Uint160{1, 2, 3} account := util.Uint160{1, 2, 3}
t.Run("get, internal method", func(t *testing.T) { t.Run("isBlocked, internal method", func(t *testing.T) {
accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) isBlocked := chain.contracts.Policy.IsBlockedInternal(chain.dao, random.Uint160())
require.NoError(t, err) require.Equal(t, false, isBlocked)
require.Equal(t, native.BlockedAccounts{}, accounts)
}) })
t.Run("get, contract method", func(t *testing.T) { t.Run("isBlocked, contract method", func(t *testing.T) {
res, err := invokeNativePolicyMethod(chain, "getBlockedAccounts") res, err := invokeNativePolicyMethod(chain, "isBlocked", random.Uint160())
require.NoError(t, err) require.NoError(t, err)
checkResult(t, res, stackitem.NewArray([]stackitem.Item{})) checkResult(t, res, stackitem.NewBool(false))
require.NoError(t, chain.persist()) require.NoError(t, chain.persist())
}) })
@ -180,18 +179,16 @@ func TestBlockedAccounts(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
checkResult(t, res, stackitem.NewBool(true)) checkResult(t, res, stackitem.NewBool(true))
accounts, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) isBlocked := chain.contracts.Policy.IsBlockedInternal(chain.dao, account)
require.NoError(t, err) require.Equal(t, isBlocked, true)
require.Equal(t, native.BlockedAccounts{account}, accounts)
require.NoError(t, chain.persist()) require.NoError(t, chain.persist())
res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE()) res, err = invokeNativePolicyMethod(chain, "unblockAccount", account.BytesBE())
require.NoError(t, err) require.NoError(t, err)
checkResult(t, res, stackitem.NewBool(true)) checkResult(t, res, stackitem.NewBool(true))
accounts, err = chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao) isBlocked = chain.contracts.Policy.IsBlockedInternal(chain.dao, account)
require.NoError(t, err) require.Equal(t, false, isBlocked)
require.Equal(t, native.BlockedAccounts{}, accounts)
require.NoError(t, chain.persist()) require.NoError(t, chain.persist())
}) })
@ -220,28 +217,6 @@ func TestBlockedAccounts(t *testing.T) {
checkResult(t, res, stackitem.NewBool(false)) checkResult(t, res, stackitem.NewBool(false))
require.NoError(t, chain.persist()) require.NoError(t, chain.persist())
}) })
t.Run("sorted", func(t *testing.T) {
accounts := []util.Uint160{
{2, 3, 4},
{4, 5, 6},
{3, 4, 5},
{1, 2, 3},
}
for _, acc := range accounts {
res, err := invokeNativePolicyMethod(chain, "blockAccount", acc.BytesBE())
require.NoError(t, err)
checkResult(t, res, stackitem.NewBool(true))
require.NoError(t, chain.persist())
}
sort.Slice(accounts, func(i, j int) bool {
return accounts[i].Less(accounts[j])
})
actual, err := chain.contracts.Policy.GetBlockedAccountsInternal(chain.dao)
require.NoError(t, err)
require.Equal(t, native.BlockedAccounts(accounts), actual)
})
} }
func invokeNativePolicyMethod(chain *Blockchain, method string, args ...interface{}) (*state.AppExecResult, error) { func invokeNativePolicyMethod(chain *Blockchain, method string, args ...interface{}) (*state.AppExecResult, error) {

View file

@ -3,7 +3,6 @@ package client
import ( import (
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/core/native"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -41,39 +40,28 @@ func (c *Client) invokeNativePolicyMethod(operation string) (int64, error) {
return topIntFromStack(result.Stack) return topIntFromStack(result.Stack)
} }
// GetBlockedAccounts invokes `getBlockedAccounts` method on a native Policy contract. // IsBlocked invokes `isBlocked` method on native Policy contract.
func (c *Client) GetBlockedAccounts() (native.BlockedAccounts, error) { func (c *Client) IsBlocked(hash util.Uint160) (bool, error) {
result, err := c.InvokeFunction(PolicyContractHash, "getBlockedAccounts", []smartcontract.Parameter{}, nil) result, err := c.InvokeFunction(PolicyContractHash, "isBlocked", []smartcontract.Parameter{{
Type: smartcontract.Hash160Type,
Value: hash,
}}, nil)
if err != nil { if err != nil {
return nil, err return false, err
} }
err = getInvocationError(result) err = getInvocationError(result)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get blocked accounts: %w", err) return false, fmt.Errorf("failed to check if account is blocked: %w", err)
} }
return topBlockedAccountsFromStack(result.Stack) return topBoolFromStack(result.Stack)
} }
func topBlockedAccountsFromStack(st []stackitem.Item) (native.BlockedAccounts, error) { // topBoolFromStack returns the top boolean value from stack
func topBoolFromStack(st []stackitem.Item) (bool, error) {
index := len(st) - 1 // top stack element is last in the array index := len(st) - 1 // top stack element is last in the array
var ( result, ok := st[index].Value().(bool)
ba native.BlockedAccounts
err error
)
items, ok := st[index].Value().([]stackitem.Item)
if !ok { if !ok {
return nil, fmt.Errorf("invalid stack item type: %s", st[index].Type()) return false, fmt.Errorf("invalid stack item type: %s", st[index].Type())
} }
ba = make(native.BlockedAccounts, len(items)) return result, nil
for i, account := range items {
val, ok := account.Value().([]byte)
if !ok {
return nil, fmt.Errorf("invalid array element: %s", account.Type())
}
ba[i], err = util.Uint160DecodeBytesLE(val)
if err != nil {
return nil, err
}
}
return ba, nil
} }

View file

@ -16,7 +16,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/native"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -378,15 +377,15 @@ var rpcClientTestCases = map[string][]rpcClientTestCase{
}, },
}, },
}, },
"getBlockedAccounts": { "isBlocked": {
{ {
name: "positive", name: "positive",
invoke: func(c *Client) (interface{}, error) { invoke: func(c *Client) (interface{}, error) {
return c.GetBlockedAccounts() return c.IsBlocked(util.Uint160{1, 2, 3})
}, },
serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"state":"HALT","gasconsumed":"2007390","script":"EMAMEmdldEJsb2NrZWRBY2NvdW50cwwUmmGkbuyXuJMG186B8VtGIJHQCTJBYn1bUg==","stack":[{"type":"Array","value":[]}],"tx":null}}`, serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"state":"HALT","gasconsumed":"2007390","script":"EMAMEmdldEJsb2NrZWRBY2NvdW50cwwUmmGkbuyXuJMG186B8VtGIJHQCTJBYn1bUg==","stack":[{"type":"Boolean","value":false}],"tx":null}}`,
result: func(c *Client) interface{} { result: func(c *Client) interface{} {
return native.BlockedAccounts{} return false
}, },
}, },
}, },