106 lines
2.9 KiB
Go
106 lines
2.9 KiB
Go
|
package extpool
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||
|
"github.com/nspcc-dev/neo-go/pkg/crypto"
|
||
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestAddGet(t *testing.T) {
|
||
|
bc := newTestChain()
|
||
|
bc.height = 10
|
||
|
|
||
|
p := New(bc)
|
||
|
t.Run("invalid witness", func(t *testing.T) {
|
||
|
ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}
|
||
|
p.testAdd(t, false, errVerification, ep)
|
||
|
})
|
||
|
t.Run("disallowed sender", func(t *testing.T) {
|
||
|
ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}}
|
||
|
p.testAdd(t, false, errDisallowedSender, ep)
|
||
|
})
|
||
|
t.Run("bad height", func(t *testing.T) {
|
||
|
ep := &payload.Extensible{ValidBlockEnd: 9}
|
||
|
p.testAdd(t, false, errInvalidHeight, ep)
|
||
|
|
||
|
ep = &payload.Extensible{ValidBlockEnd: 10}
|
||
|
p.testAdd(t, false, nil, ep)
|
||
|
})
|
||
|
t.Run("good", func(t *testing.T) {
|
||
|
ep := &payload.Extensible{ValidBlockEnd: 100}
|
||
|
p.testAdd(t, true, nil, ep)
|
||
|
require.Equal(t, ep, p.Get(ep.Hash()))
|
||
|
|
||
|
p.testAdd(t, false, nil, ep)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestRemoveStale(t *testing.T) {
|
||
|
bc := newTestChain()
|
||
|
bc.height = 10
|
||
|
|
||
|
p := New(bc)
|
||
|
eps := []*payload.Extensible{
|
||
|
{ValidBlockEnd: 11}, // small height
|
||
|
{ValidBlockEnd: 12}, // good
|
||
|
{Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender
|
||
|
{Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness
|
||
|
}
|
||
|
for i := range eps {
|
||
|
p.testAdd(t, true, nil, eps[i])
|
||
|
}
|
||
|
bc.verifyWitness = func(u util.Uint160) bool { println("call"); return u[0] != 0x12 }
|
||
|
bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
|
||
|
p.RemoveStale(11)
|
||
|
require.Nil(t, p.Get(eps[0].Hash()))
|
||
|
require.Equal(t, eps[1], p.Get(eps[1].Hash()))
|
||
|
require.Nil(t, p.Get(eps[2].Hash()))
|
||
|
require.Nil(t, p.Get(eps[3].Hash()))
|
||
|
}
|
||
|
|
||
|
func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) {
|
||
|
ok, err := p.Add(ep)
|
||
|
if expectedErr != nil {
|
||
|
require.True(t, errors.Is(err, expectedErr), "got: %v", err)
|
||
|
} else {
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
require.Equal(t, expectedOk, ok)
|
||
|
}
|
||
|
|
||
|
type testChain struct {
|
||
|
blockchainer.Blockchainer
|
||
|
height uint32
|
||
|
verifyWitness func(util.Uint160) bool
|
||
|
isAllowed func(util.Uint160) bool
|
||
|
}
|
||
|
|
||
|
// errVerification is the sentinel error returned by testChain.VerifyWitness
// when the stub witness check rejects a sender.
var errVerification = errors.New("verification failed")
|
||
|
|
||
|
func newTestChain() *testChain {
|
||
|
return &testChain{
|
||
|
verifyWitness: func(u util.Uint160) bool {
|
||
|
return u[0] != 0x42
|
||
|
},
|
||
|
isAllowed: func(u util.Uint160) bool {
|
||
|
return u[0] != 0x42 && u[0] != 0x41
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
func (c *testChain) VerifyWitness(u util.Uint160, _ crypto.Verifiable, _ *transaction.Witness, _ int64) error {
|
||
|
if !c.verifyWitness(u) {
|
||
|
return errVerification
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool {
|
||
|
return c.isAllowed(u)
|
||
|
}
|
||
|
// BlockHeight returns the height configured on the stub chain.
func (c *testChain) BlockHeight() uint32 {
	return c.height
}
|