mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-23 03:38:35 +00:00
network: return error if header message is too big
Big messages can still be processed but only first 2000 headers will be used.
This commit is contained in:
parent
637c99eda7
commit
9b8b77c9ea
4 changed files with 62 additions and 38 deletions
|
@ -201,13 +201,11 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
||||||
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
|
return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command))
|
||||||
}
|
}
|
||||||
p.DecodeBinary(r)
|
p.DecodeBinary(r)
|
||||||
if r.Err != nil {
|
if r.Err == nil || r.Err == payload.ErrTooManyHeaders {
|
||||||
return r.Err
|
m.Payload = p
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Payload = p
|
return r.Err
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes a Message to any given BinWriter.
|
// Encode encodes a Message to any given BinWriter.
|
||||||
|
|
|
@ -3,6 +3,7 @@ package payload
|
||||||
import (
|
import (
|
||||||
"github.com/CityOfZion/neo-go/pkg/core"
|
"github.com/CityOfZion/neo-go/pkg/core"
|
||||||
"github.com/CityOfZion/neo-go/pkg/io"
|
"github.com/CityOfZion/neo-go/pkg/io"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Headers payload.
|
// Headers payload.
|
||||||
|
@ -15,12 +16,17 @@ const (
|
||||||
MaxHeadersAllowed = 2000
|
MaxHeadersAllowed = 2000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrTooManyHeaders is an error returned when too many headers were received.
|
||||||
|
var ErrTooManyHeaders = errors.Errorf("too many headers were received (max: %d)", MaxHeadersAllowed)
|
||||||
|
|
||||||
// DecodeBinary implements Serializable interface.
|
// DecodeBinary implements Serializable interface.
|
||||||
func (p *Headers) DecodeBinary(br *io.BinReader) {
|
func (p *Headers) DecodeBinary(br *io.BinReader) {
|
||||||
lenHeaders := br.ReadVarUint()
|
lenHeaders := br.ReadVarUint()
|
||||||
|
|
||||||
|
var limitExceeded bool
|
||||||
|
|
||||||
// C# node does it silently
|
// C# node does it silently
|
||||||
if lenHeaders > MaxHeadersAllowed {
|
if limitExceeded = lenHeaders > MaxHeadersAllowed; limitExceeded {
|
||||||
lenHeaders = MaxHeadersAllowed
|
lenHeaders = MaxHeadersAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +37,10 @@ func (p *Headers) DecodeBinary(br *io.BinReader) {
|
||||||
header.DecodeBinary(br)
|
header.DecodeBinary(br)
|
||||||
p.Hdrs[i] = header
|
p.Hdrs[i] = header
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if br.Err == nil && limitExceeded {
|
||||||
|
br.Err = ErrTooManyHeaders
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeBinary implements Serializable interface.
|
// EncodeBinary implements Serializable interface.
|
||||||
|
|
|
@ -11,36 +11,39 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHeadersEncodeDecode(t *testing.T) {
|
func TestHeadersEncodeDecode(t *testing.T) {
|
||||||
headers := &Headers{[]*core.Header{
|
t.Run("normal case", func(t *testing.T) {
|
||||||
{
|
headers := newTestHeaders(3)
|
||||||
BlockBase: core.BlockBase{
|
|
||||||
Version: 0,
|
|
||||||
Index: 1,
|
|
||||||
Script: transaction.Witness{
|
|
||||||
InvocationScript: []byte{0x0},
|
|
||||||
VerificationScript: []byte{0x1},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
{
|
|
||||||
BlockBase: core.BlockBase{
|
|
||||||
Version: 0,
|
|
||||||
Index: 2,
|
|
||||||
Script: transaction.Witness{
|
|
||||||
InvocationScript: []byte{0x0},
|
|
||||||
VerificationScript: []byte{0x1},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
{
|
|
||||||
BlockBase: core.BlockBase{
|
|
||||||
Version: 0,
|
|
||||||
Index: 3,
|
|
||||||
Script: transaction.Witness{
|
|
||||||
InvocationScript: []byte{0x0},
|
|
||||||
VerificationScript: []byte{0x1},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
}}
|
|
||||||
|
|
||||||
|
testHeadersEncodeDecode(t, headers, 3, false)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("more than max", func(t *testing.T) {
|
||||||
|
const sent = MaxHeadersAllowed + 1
|
||||||
|
headers := newTestHeaders(sent)
|
||||||
|
|
||||||
|
testHeadersEncodeDecode(t, headers, MaxHeadersAllowed, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestHeaders(n int) *Headers {
|
||||||
|
headers := &Headers{Hdrs: make([]*core.Header, n)}
|
||||||
|
|
||||||
|
for i := range headers.Hdrs {
|
||||||
|
headers.Hdrs[i] = &core.Header{
|
||||||
|
BlockBase: core.BlockBase{
|
||||||
|
Index: uint32(i + 1),
|
||||||
|
Script: transaction.Witness{
|
||||||
|
InvocationScript: []byte{0x0},
|
||||||
|
VerificationScript: []byte{0x1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func testHeadersEncodeDecode(t *testing.T, headers *Headers, expected int, limit bool) {
|
||||||
buf := io.NewBufBinWriter()
|
buf := io.NewBufBinWriter()
|
||||||
headers.EncodeBinary(buf.BinWriter)
|
headers.EncodeBinary(buf.BinWriter)
|
||||||
assert.Nil(t, buf.Err)
|
assert.Nil(t, buf.Err)
|
||||||
|
@ -49,9 +52,16 @@ func TestHeadersEncodeDecode(t *testing.T) {
|
||||||
r := io.NewBinReaderFromBuf(b)
|
r := io.NewBinReaderFromBuf(b)
|
||||||
headersDecode := &Headers{}
|
headersDecode := &Headers{}
|
||||||
headersDecode.DecodeBinary(r)
|
headersDecode.DecodeBinary(r)
|
||||||
assert.Nil(t, r.Err)
|
|
||||||
|
|
||||||
for i := 0; i < len(headers.Hdrs); i++ {
|
var err error
|
||||||
|
if limit {
|
||||||
|
err = ErrTooManyHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, err, r.Err)
|
||||||
|
assert.Equal(t, expected, len(headersDecode.Hdrs))
|
||||||
|
|
||||||
|
for i := 0; i < len(headersDecode.Hdrs); i++ {
|
||||||
assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
|
assert.Equal(t, headers.Hdrs[i].Version, headersDecode.Hdrs[i].Version)
|
||||||
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
|
assert.Equal(t, headers.Hdrs[i].Index, headersDecode.Hdrs[i].Index)
|
||||||
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)
|
assert.Equal(t, headers.Hdrs[i].Script, headersDecode.Hdrs[i].Script)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/CityOfZion/neo-go/pkg/io"
|
"github.com/CityOfZion/neo-go/pkg/io"
|
||||||
|
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -87,7 +88,12 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
|
||||||
r := io.NewBinReaderFromIO(p.conn)
|
r := io.NewBinReaderFromIO(p.conn)
|
||||||
for {
|
for {
|
||||||
msg := &Message{}
|
msg := &Message{}
|
||||||
if err = msg.Decode(r); err != nil {
|
err := msg.Decode(r)
|
||||||
|
|
||||||
|
if err == payload.ErrTooManyHeaders {
|
||||||
|
t.log.Warn("not all headers were processed")
|
||||||
|
r.Err = nil
|
||||||
|
} else if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err = t.server.handleMessage(p, msg); err != nil {
|
if err = t.server.handleMessage(p, msg); err != nil {
|
||||||
|
|
Loading…
Reference in a new issue