neoneo-go/pkg/network/extpool/pool_test.go

106 lines
2.9 KiB
Go
Raw Normal View History

package extpool
import (
"errors"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
// TestAddGet exercises Pool.Add against a test chain at height 10: first a
// series of payloads that must not be stored, then a valid payload that is
// stored, retrievable by hash via Get, and not stored a second time.
func TestAddGet(t *testing.T) {
	chain := newTestChain()
	chain.height = 10
	pool := New(chain)

	// attempt is one Add call that is expected to report "not added";
	// wantErr is nil when no error is expected alongside that result.
	type attempt struct {
		ep      *payload.Extensible
		wantErr error
	}
	rejections := []struct {
		name     string
		attempts []attempt
	}{
		{
			// Sender 0x42 fails the default verifyWitness callback.
			name: "invalid witness",
			attempts: []attempt{
				{ep: &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}, wantErr: errVerification},
			},
		},
		{
			// Sender 0x41 passes verifyWitness but fails the default isAllowed callback.
			name: "disallowed sender",
			attempts: []attempt{
				{ep: &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}}, wantErr: errDisallowedSender},
			},
		},
		{
			// ValidBlockEnd below the chain height is an error; equal to
			// the chain height is expected to be rejected without one.
			name: "bad height",
			attempts: []attempt{
				{ep: &payload.Extensible{ValidBlockEnd: 9}, wantErr: errInvalidHeight},
				{ep: &payload.Extensible{ValidBlockEnd: 10}, wantErr: nil},
			},
		},
	}
	for _, tc := range rejections {
		t.Run(tc.name, func(t *testing.T) {
			for _, a := range tc.attempts {
				pool.testAdd(t, false, a.wantErr, a.ep)
			}
		})
	}

	t.Run("good", func(t *testing.T) {
		valid := &payload.Extensible{ValidBlockEnd: 100}
		pool.testAdd(t, true, nil, valid)
		require.Equal(t, valid, pool.Get(valid.Hash()))
		// Adding the same payload again: no error, but it is not added anew.
		pool.testAdd(t, false, nil, valid)
	})
}
// TestRemoveStale adds four payloads while the default test-chain callbacks
// accept all of them, then swaps the callbacks and calls RemoveStale(11).
// Only the payload that is still valid is expected to remain retrievable.
func TestRemoveStale(t *testing.T) {
	bc := newTestChain()
	bc.height = 10
	p := New(bc)

	eps := []*payload.Extensible{
		// small height
		{ValidBlockEnd: 11},
		// good
		{ValidBlockEnd: 12},
		// invalid sender
		{Sender: util.Uint160{0x11}, ValidBlockEnd: 12},
		// invalid witness
		{Sender: util.Uint160{0x12}, ValidBlockEnd: 12},
	}
	for _, ep := range eps {
		p.testAdd(t, true, nil, ep)
	}

	// The callbacks are replaced only after every payload is in the pool:
	// from here on sender 0x12 fails witness verification and sender 0x11
	// is no longer allowed.
	bc.verifyWitness = func(u util.Uint160) bool { return u[0] != 0x12 }
	bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
	p.RemoveStale(11)

	require.Nil(t, p.Get(eps[0].Hash()))
	require.Equal(t, eps[1], p.Get(eps[1].Hash()))
	require.Nil(t, p.Get(eps[2].Hash()))
	require.Nil(t, p.Get(eps[3].Hash()))
}
// testAdd calls p.Add(ep) and asserts on both results: the returned flag must
// equal expectedOk, and the returned error must match expectedErr according
// to errors.Is, or be nil when expectedErr is nil.
func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) {
	// Mark this as a helper so failures are attributed to the calling test
	// line instead of to the assertions below.
	t.Helper()

	ok, err := p.Add(ep)
	if expectedErr != nil {
		require.True(t, errors.Is(err, expectedErr), "got: %v", err)
	} else {
		require.NoError(t, err)
	}
	require.Equal(t, expectedOk, ok)
}
// testChain is a stub chain used by the pool tests. Its behaviour is driven
// by a settable height and two replaceable callbacks.
type testChain struct {
	// Embedded interface; newTestChain leaves it unset, so only the methods
	// defined directly on testChain have an implementation behind them.
	blockchainer.Blockchainer

	// height is the value reported by BlockHeight.
	height uint32

	// verifyWitness decides the outcome of VerifyWitness and isAllowed the
	// outcome of IsExtensibleAllowed; both receive the sender's script hash.
	verifyWitness, isAllowed func(util.Uint160) bool
}
// errVerification is the error testChain.VerifyWitness returns when the
// verifyWitness callback rejects the given script hash.
var errVerification = errors.New("verification failed")
// newTestChain returns a testChain with default callbacks installed: witness
// verification fails only for hashes whose first byte is 0x42, and hashes
// whose first byte is 0x41 or 0x42 are not allowed. The height is left at
// its zero value.
func newTestChain() *testChain {
	chain := new(testChain)
	chain.verifyWitness = func(sender util.Uint160) bool {
		return sender[0] != 0x42
	}
	chain.isAllowed = func(sender util.Uint160) bool {
		switch sender[0] {
		case 0x41, 0x42:
			return false
		default:
			return true
		}
	}
	return chain
}
// VerifyWitness consults the verifyWitness callback with u and ignores the
// remaining arguments. It returns nil when the callback accepts u and
// errVerification otherwise.
func (c *testChain) VerifyWitness(u util.Uint160, _ hash.Hashable, _ *transaction.Witness, _ int64) error {
	if c.verifyWitness(u) {
		return nil
	}
	return errVerification
}
// IsExtensibleAllowed reports whatever the isAllowed callback returns for u.
func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool {
	allowed := c.isAllowed(u)
	return allowed
}
// BlockHeight returns the height currently stored in the test chain.
func (c *testChain) BlockHeight() uint32 {
	return c.height
}