Merge pull request #1948 from nspcc-dev/extensible-dos

Prevent network DoS attack using Extensible payloads
Roman Khimov, 2021-05-12 10:58:51 +03:00 (committed by GitHub)
commit fec214055f
6 changed files with 134 additions and 55 deletions

View file

@@ -30,4 +30,6 @@ type ApplicationConfiguration struct {
 	Oracle    OracleConfiguration `yaml:"Oracle"`
 	P2PNotary P2PNotary           `yaml:"P2PNotary"`
 	StateRoot StateRoot           `yaml:"StateRoot"`
+	// ExtensiblePoolSize is the maximum amount of the extensible payloads from a single sender.
+	ExtensiblePoolSize int `yaml:"ExtensiblePoolSize"`
 }
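With the yaml tag above, the new option is set in the ApplicationConfiguration section of a node's YAML config. A minimal sketch of how such a value gets decoded, assuming gopkg.in/yaml.v2 and a deliberately stripped-down struct (the real ApplicationConfiguration has many more fields):

```go
package main

import (
	"fmt"

	"gopkg.in/yaml.v2" // assumed YAML decoder, matching the struct tag above
)

// miniAppConfig mirrors only the field added in this commit; illustrative only.
type miniAppConfig struct {
	ExtensiblePoolSize int `yaml:"ExtensiblePoolSize"`
}

func main() {
	raw := []byte("ExtensiblePoolSize: 20\n")

	var cfg miniAppConfig
	if err := yaml.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.ExtensiblePoolSize) // prints 20
}
```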

View file

@@ -795,16 +795,12 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
 func (bc *Blockchain) updateExtensibleWhitelist(height uint32) error {
 	updateCommittee := native.ShouldUpdateCommittee(height, bc)
-	oracles, oh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, noderoles.Oracle, height)
-	if err != nil {
-		return err
-	}
 	stateVals, sh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, noderoles.StateValidator, height)
 	if err != nil {
 		return err
 	}
-	if bc.extensible.Load() != nil && !updateCommittee && oh != height && sh != height {
+	if bc.extensible.Load() != nil && !updateCommittee && sh != height {
 		return nil
 	}
@@ -817,15 +813,6 @@ func (bc *Blockchain) updateExtensibleWhitelist(height uint32) error {
 	newList = append(newList, hash.Hash160(script))
 	bc.updateExtensibleList(&newList, bc.contracts.NEO.GetNextBlockValidatorsInternal())
-	if len(oracles) > 0 {
-		h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, noderoles.Oracle)
-		if err != nil {
-			return err
-		}
-		newList = append(newList, h)
-		bc.updateExtensibleList(&newList, oracles)
-	}
 	if len(stateVals) > 0 {
 		h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, noderoles.StateValidator)
 		if err != nil {
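The whitelist built here now covers the committee address, the next block validators' multisig and (when designated) the state validators; oracle nodes are dropped from it. Below is a sketch of the kind of membership check that Blockchain.IsExtensibleAllowed (consulted by the payload pool further down) performs against such a list; the helper and main function are illustrative only, not code from this change:

```go
package main

import (
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/util"
)

// isWhitelisted is a stand-in for the idea behind IsExtensibleAllowed:
// a sender is accepted only if its script hash is one of the whitelisted hashes.
func isWhitelisted(whitelist []util.Uint160, sender util.Uint160) bool {
	for _, h := range whitelist {
		if h == sender { // util.Uint160 is an array type, so == compares contents
			return true
		}
	}
	return false
}

func main() {
	allowed := []util.Uint160{{0x01}, {0x02}}
	fmt.Println(isWhitelisted(allowed, util.Uint160{0x02})) // true
	fmt.Println(isWhitelisted(allowed, util.Uint160{0x42})) // false
}
```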

View file

@@ -1,6 +1,7 @@
 package extpool
 
 import (
+	"container/list"
 	"errors"
 	"sync"
@@ -12,15 +13,24 @@ import (
 // Pool represents pool of extensible payloads.
 type Pool struct {
 	lock     sync.RWMutex
-	verified map[util.Uint256]*payload.Extensible
-	chain    blockchainer.Blockchainer
+	verified map[util.Uint256]*list.Element
+	senders  map[util.Uint160]*list.List
+	// singleCap represents maximum number of payloads from the single sender.
+	singleCap int
+	chain     blockchainer.Blockchainer
 }
 
 // New returns new payload pool using provided chain.
-func New(bc blockchainer.Blockchainer) *Pool {
+func New(bc blockchainer.Blockchainer, capacity int) *Pool {
+	if capacity <= 0 {
+		panic("invalid capacity")
+	}
 	return &Pool{
-		verified: make(map[util.Uint256]*payload.Extensible),
-		chain:    bc,
+		verified:  make(map[util.Uint256]*list.Element),
+		senders:   make(map[util.Uint160]*list.List),
+		singleCap: capacity,
+		chain:     bc,
 	}
 }
@@ -44,7 +54,17 @@ func (p *Pool) Add(e *payload.Extensible) (bool, error) {
 	if _, ok := p.verified[h]; ok {
 		return false, nil
 	}
-	p.verified[h] = e
+
+	lst, ok := p.senders[e.Sender]
+	if ok && lst.Len() >= p.singleCap {
+		value := lst.Remove(lst.Front())
+		delete(p.verified, value.(*payload.Extensible).Hash())
+	} else if !ok {
+		lst = list.New()
+		p.senders[e.Sender] = lst
+	}
+
+	p.verified[h] = lst.PushBack(e)
 	return true, nil
 }
@@ -72,7 +92,11 @@ func (p *Pool) Get(h util.Uint256) *payload.Extensible {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
-	return p.verified[h]
+	elem, ok := p.verified[h]
+	if !ok {
+		return nil
+	}
+	return elem.Value.(*payload.Extensible)
 }
 
 const extensibleVerifyMaxGAS = 2000000
@@ -81,13 +105,27 @@ const extensibleVerifyMaxGAS = 2000000
 func (p *Pool) RemoveStale(index uint32) {
 	p.lock.Lock()
 	defer p.lock.Unlock()
-	for h, e := range p.verified {
-		if e.ValidBlockEnd <= index || !p.chain.IsExtensibleAllowed(e.Sender) {
-			delete(p.verified, h)
-			continue
-		}
-		if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil {
-			delete(p.verified, h)
-		}
+
+	for s, lst := range p.senders {
+		for elem := lst.Front(); elem != nil; {
+			e := elem.Value.(*payload.Extensible)
+			h := e.Hash()
+			old := elem
+			elem = elem.Next()
+
+			if e.ValidBlockEnd <= index || !p.chain.IsExtensibleAllowed(e.Sender) {
+				delete(p.verified, h)
+				lst.Remove(old)
+				continue
+			}
+			if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil {
+				delete(p.verified, h)
+				lst.Remove(old)
+				continue
+			}
+		}
+		if lst.Len() == 0 {
+			delete(p.senders, s)
+		}
 	}
 }
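The essence of the DoS fix is visible in Add and RemoveStale above: payloads are grouped per sender in container/list queues capped at singleCap, with the verified map indexing list elements by hash, so a flood of valid-looking payloads from one sender can only ever occupy a bounded slice of the pool. A self-contained sketch of the same per-sender bounded-FIFO pattern, using plain strings instead of the real payload and util types (all names here are illustrative):

```go
package main

import (
	"container/list"
	"fmt"
)

// boundedQueues demonstrates the eviction pattern used by the pool above.
// Only the technique (per-sender FIFO with a fixed cap, plus a key->element
// index) mirrors the Pool changes; the types are simplified.
type boundedQueues struct {
	byKey   map[string]*list.Element // analogous to Pool.verified
	senders map[string]*list.List    // analogous to Pool.senders
	cap     int                      // analogous to Pool.singleCap
}

func newBoundedQueues(capacity int) *boundedQueues {
	return &boundedQueues{
		byKey:   make(map[string]*list.Element),
		senders: make(map[string]*list.List),
		cap:     capacity,
	}
}

// add stores key under sender, evicting the sender's oldest key at capacity.
func (q *boundedQueues) add(sender, key string) {
	lst, ok := q.senders[sender]
	if ok && lst.Len() >= q.cap {
		oldest := lst.Remove(lst.Front()) // drop the oldest element...
		delete(q.byKey, oldest.(string))  // ...and its index entry
	} else if !ok {
		lst = list.New()
		q.senders[sender] = lst
	}
	q.byKey[key] = lst.PushBack(key)
}

func main() {
	q := newBoundedQueues(3)
	for _, k := range []string{"a", "b", "c", "d"} {
		q.add("sender1", k)
	}
	_, stillThere := q.byKey["a"]
	fmt.Println(stillThere, len(q.byKey)) // false 3: "a" was evicted
}
```

Keeping the hash-to-element index and the per-sender lists in sync is what lets Get stay a single map lookup, while RemoveStale can walk one sender's queue at a time and drop the whole sender entry once its list is empty.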

View file

@@ -16,7 +16,7 @@ func TestAddGet(t *testing.T) {
 	bc := newTestChain()
 	bc.height = 10
 
-	p := New(bc)
+	p := New(bc, 100)
 	t.Run("invalid witness", func(t *testing.T) {
 		ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}
 		p.testAdd(t, false, errVerification, ep)
@@ -41,11 +41,52 @@ func TestAddGet(t *testing.T) {
 	})
 }
 
+func TestCapacityLimit(t *testing.T) {
+	bc := newTestChain()
+	bc.height = 10
+
+	t.Run("invalid capacity", func(t *testing.T) {
+		require.Panics(t, func() { New(bc, 0) })
+	})
+
+	p := New(bc, 3)
+
+	first := &payload.Extensible{ValidBlockEnd: 11}
+	p.testAdd(t, true, nil, first)
+	for _, height := range []uint32{12, 13} {
+		ep := &payload.Extensible{ValidBlockEnd: height}
+		p.testAdd(t, true, nil, ep)
+	}
+
+	require.NotNil(t, p.Get(first.Hash()))
+
+	ok, err := p.Add(&payload.Extensible{ValidBlockEnd: 14})
+	require.True(t, ok)
+	require.NoError(t, err)
+
+	require.Nil(t, p.Get(first.Hash()))
+}
+
+// This test checks that sender count is updated
+// when oldest payload is removed during `Add`.
+func TestDecreaseSenderOnEvict(t *testing.T) {
+	bc := newTestChain()
+	bc.height = 10
+
+	p := New(bc, 2)
+	senders := []util.Uint160{{1}, {2}, {3}}
+
+	for i := uint32(11); i < 17; i++ {
+		ep := &payload.Extensible{Sender: senders[i%3], ValidBlockEnd: i}
+		p.testAdd(t, true, nil, ep)
+	}
+}
+
 func TestRemoveStale(t *testing.T) {
 	bc := newTestChain()
 	bc.height = 10
 
-	p := New(bc)
+	p := New(bc, 100)
 	eps := []*payload.Extensible{
 		{ValidBlockEnd: 11}, // small height
 		{ValidBlockEnd: 12}, // good
@@ -55,7 +96,7 @@ func TestRemoveStale(t *testing.T) {
 	for i := range eps {
 		p.testAdd(t, true, nil, eps[i])
 	}
-	bc.verifyWitness = func(u util.Uint160) bool { println("call"); return u[0] != 0x12 }
+	bc.verifyWitness = func(u util.Uint160) bool { return u[0] != 0x12 }
 	bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
 	p.RemoveStale(11)
 	require.Nil(t, p.Get(eps[0].Hash()))

View file

@@ -30,11 +30,12 @@ import (
 
 const (
 	// peer numbers are arbitrary at the moment.
 	defaultMinPeers           = 5
 	defaultAttemptConnPeers   = 20
 	defaultMaxPeers           = 100
+	defaultExtensiblePoolSize = 20
 	maxBlockBatch             = 200
 	minPoolCount              = 30
 )
 
 var (
@@ -121,6 +122,12 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
 		return nil, errors.New("logger is a required parameter")
 	}
 
+	if config.ExtensiblePoolSize <= 0 {
+		config.ExtensiblePoolSize = defaultExtensiblePoolSize
+		log.Info("ExtensiblePoolSize is not set or wrong, using default value",
+			zap.Int("ExtensiblePoolSize", config.ExtensiblePoolSize))
+	}
+
 	s := &Server{
 		ServerConfig:   config,
 		chain:          chain,
@@ -132,7 +139,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
 		unregister:     make(chan peerDrop),
 		peers:          make(map[Peer]bool),
 		syncReached:    atomic.NewBool(false),
-		extensiblePool: extpool.New(chain),
+		extensiblePool: extpool.New(chain, config.ExtensiblePoolSize),
 		log:            log,
 		transactions:   make(chan *transaction.Transaction, 64),
 	}

View file

@@ -78,6 +78,9 @@ type (
 
 		// StateRootCfg is stateroot module configuration.
 		StateRootCfg config.StateRoot
+
+		// ExtensiblePoolSize is size of the pool for extensible payloads from a single sender.
+		ExtensiblePoolSize int
 	}
 )
@@ -93,24 +96,25 @@ func NewServerConfig(cfg config.Config) ServerConfig {
 	}
 
 	return ServerConfig{
 		UserAgent:          cfg.GenerateUserAgent(),
 		Address:            appConfig.Address,
 		AnnouncedPort:      appConfig.AnnouncedNodePort,
 		Port:               appConfig.NodePort,
 		Net:                protoConfig.Magic,
 		Relay:              appConfig.Relay,
 		Seeds:              protoConfig.SeedList,
 		DialTimeout:        appConfig.DialTimeout * time.Second,
 		ProtoTickInterval:  appConfig.ProtoTickInterval * time.Second,
 		PingInterval:       appConfig.PingInterval * time.Second,
 		PingTimeout:        appConfig.PingTimeout * time.Second,
 		MaxPeers:           appConfig.MaxPeers,
 		AttemptConnPeers:   appConfig.AttemptConnPeers,
 		MinPeers:           appConfig.MinPeers,
 		Wallet:             wc,
 		TimePerBlock:       time.Duration(protoConfig.SecondsPerBlock) * time.Second,
 		OracleCfg:          appConfig.Oracle,
 		P2PNotaryCfg:       appConfig.P2PNotary,
 		StateRootCfg:       appConfig.StateRoot,
+		ExtensiblePoolSize: appConfig.ExtensiblePoolSize,
 	}
 }