forked from TrueCloudLab/neoneo-go
network: return error if header message is too big
Big messages can still be processed but only first 2000 headers will be used.
This commit is contained in:
parent
637c99eda7
commit
9b8b77c9ea
4 changed files with 62 additions and 38 deletions
|
@ -201,13 +201,11 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
|||
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
|
||||
}
|
||||
p.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return r.Err
|
||||
if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
|
||||
m.Payload = p
|
||||
}
|
||||
|
||||
m.Payload = p
|
||||
|
||||
return nil
|
||||
return r.Err
|
||||
}
|
||||
|
||||
// Encode encodes a Message to any given BinWriter.
|
||||
|
|
|
@ -3,6 +3,7 @@ package payload
|
|||
import (
|
||||
"github.com/CityOfZion/neo-go/pkg/core"
|
||||
"github.com/CityOfZion/neo-go/pkg/io"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Headers payload.
|
||||
|
@ -15,12 +16,17 @@ const (
|
|||
MaxHeadersAllowed = 2000
|
||||
)
|
||||
|
||||
// ErrTooManyHeaders is an error returned when too many headers were received.
|
||||
var ErrTooManyHeaders = errors.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
|
||||
|
||||
// DecodeBinary implements Serializable interface.
|
||||
func (p *Headers) DecodeBinary(br *io.BinReader) {
|
||||
lenHeaders := br.ReadVarUint()
|
||||
|
||||
var limitExceeded bool
|
||||
|
||||
// C# node does it silently
|
||||
if lenHeaders > MaxHeadersAllowed {
|
||||
if limitExceeded = lenHeaders > MaxHeadersAllowed; limitExceeded {
|
||||
lenHeaders = MaxHeadersAllowed
|
||||
}
|
||||
|
||||
|
@ -31,6 +37,10 @@ func (p *Headers) DecodeBinary(br *io.BinReader) {
|
|||
header.DecodeBinary(br)
|
||||
p.Hdrs[i] = header
|
||||
}
|
||||
|
||||
if br.Err == nil && limitExceeded {
|
||||
br.Err = ErrTooManyHeaders
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeBinary implements Serializable interface.
|
||||
|
|
|
@ -11,36 +11,39 @@ import (
|
|||
)
|
||||
|
||||
func TestHeadersEncodeDecode(t *testing.T) {
|
||||
headers := &Headers{[]*core.Header{
|
||||
{
|
||||
BlockBase: core.BlockBase{
|
||||
Version: 0,
|
||||
Index: 1,
|
||||
Script: transaction.Witness{
|
||||
InvocationScript: []byte{0x0},
|
||||
VerificationScript: []byte{0x1},
|
||||
},
|
||||
}},
|
||||
{
|
||||
BlockBase: core.BlockBase{
|
||||
Version: 0,
|
||||
Index: 2,
|
||||
Script: transaction.Witness{
|
||||
InvocationScript: []byte{0x0},
|
||||
VerificationScript: []byte{0x1},
|
||||
},
|
||||
}},
|
||||
{
|
||||
BlockBase: core.BlockBase{
|
||||
Version: 0,
|
||||
Index: 3,
|
||||
Script: transaction.Witness{
|
||||
InvocationScript: []byte{0x0},
|
||||
VerificationScript: []byte{0x1},
|
||||
},
|
||||
}},
|
||||
}}
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
headers := newTestHeaders(3)
|
||||
|
||||
testHeadersEncodeDecode(t, headers, 3, false)
|
||||
})
|
||||
|
||||
t.Run("more than max", func(t *testing.T) {
|
||||
const sent = MaxHeadersAllowed + 1
|
||||
headers := newTestHeaders(sent)
|
||||
|
||||
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true)
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHeaders(n int) *Headers {
|
||||
headers := &Headers{Hdrs: make([]*core.Header, n)}
|
||||
|
||||
for i := range headers.Hdrs {
|
||||
headers.Hdrs[i] = &core.Header{
|
||||
BlockBase: core.BlockBase{
|
||||
Index: uint32(i + 1),
|
||||
Script: transaction.Witness{
|
||||
InvocationScript: []byte{0x0},
|
||||
VerificationScript: []byte{0x1},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) {
|
||||
buf := io.NewBufBinWriter()
|
||||
headers.EncodeBinary(buf.BinWriter)
|
||||
assert.Nil(t, buf.Err)
|
||||
|
@ -49,9 +52,16 @@ func TestHeadersEncodeDecode(t *testing.T) {
|
|||
r := io.NewBinReaderFromBuf(b)
|
||||
headersDecode := &Headers{}
|
||||
headersDecode.DecodeBinary(r)
|
||||
assert.Nil(t, r.Err)
|
||||
|
||||
for i := 0; i < len(headers.Hdrs); i++ {
|
||||
var err error
|
||||
if limit {
|
||||
err = ErrTooManyHeaders
|
||||
}
|
||||
|
||||
assert.Equal(t, err, r.Err)
|
||||
assert.Equal(t, expected, len(headersDecode.Hdrs))
|
||||
|
||||
for i := 0; i < len(headersDecode.Hdrs); i++ {
|
||||
assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
|
||||
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
|
||||
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/io"
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -87,7 +88,12 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
|
|||
r := io.NewBinReaderFromIO(p.conn)
|
||||
for {
|
||||
msg := &Message{}
|
||||
if err = msg.Decode(r); err != nil {
|
||||
err := msg.Decode(r)
|
||||
|
||||
if err == payload.ErrTooManyHeaders {
|
||||
t.log.Warn("not all headers were processed")
|
||||
r.Err = nil
|
||||
} else if err != nil {
|
||||
break
|
||||
}
|
||||
if err = t.server.handleMessage(p, msg); err != nil {
|
||||
|
|
Loading…
Reference in a new issue