package extpool

import (
	"errors"
	"testing"

	"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
	"github.com/nspcc-dev/neo-go/pkg/network/payload"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"github.com/stretchr/testify/require"
)

// TestAddGet drives Pool.Add against a test chain whose height is 10 and
// checks both the returned flag and the returned error for each scenario.
// All subtests share one pool, so they run in a fixed order: the rejection
// scenarios first, then a successful insertion followed by a duplicate.
func TestAddGet(t *testing.T) {
	chain := newTestChain()
	chain.height = 10
	pool := New(chain)

	// step is a single Add call that is expected to report ok == false.
	type step struct {
		ep      *payload.Extensible
		wantErr error
	}
	rejected := []struct {
		name  string
		steps []step
	}{
		{
			name: "invalid witness",
			steps: []step{
				// Sender 0x42 fails the witness check of the default test chain.
				{ep: &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}, wantErr: errVerification},
			},
		},
		{
			name: "disallowed sender",
			steps: []step{
				// Sender 0x41 passes the witness check but is not an allowed sender.
				{ep: &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}}, wantErr: errDisallowedSender},
			},
		},
		{
			name: "bad height",
			steps: []step{
				// ValidBlockEnd below the chain height is expected to be an error...
				{ep: &payload.Extensible{ValidBlockEnd: 9}, wantErr: errInvalidHeight},
				// ...while ValidBlockEnd equal to it is expected to be dropped silently.
				{ep: &payload.Extensible{ValidBlockEnd: 10}, wantErr: nil},
			},
		},
	}
	for _, tc := range rejected {
		t.Run(tc.name, func(t *testing.T) {
			for _, s := range tc.steps {
				pool.testAdd(t, false, s.wantErr, s.ep)
			}
		})
	}

	t.Run("good", func(t *testing.T) {
		accepted := &payload.Extensible{ValidBlockEnd: 100}
		pool.testAdd(t, true, nil, accepted)
		require.Equal(t, accepted, pool.Get(accepted.Hash()))

		// Adding the very same payload again must report "not added" with no error.
		pool.testAdd(t, false, nil, accepted)
	})
}

// TestRemoveStale fills the pool with four payloads that are all accepted by
// the default test chain, then swaps in stricter witness and sender checks
// and calls RemoveStale(11). Afterwards only the payload marked "good" is
// expected to still be retrievable from the pool.
func TestRemoveStale(t *testing.T) {
	bc := newTestChain()
	bc.height = 10

	p := New(bc)
	eps := []*payload.Extensible{
		{ValidBlockEnd: 11},                             // small height
		{ValidBlockEnd: 12},                             // good
		{Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender
		{Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness
	}
	for i := range eps {
		p.testAdd(t, true, nil, eps[i])
	}
	// Tighten the chain checks only after insertion, so that senders 0x11 and
	// 0x12 were accepted by Add but are rejected from now on.
	bc.verifyWitness = func(u util.Uint160) bool { return u[0] != 0x12 }
	bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
	p.RemoveStale(11)
	require.Nil(t, p.Get(eps[0].Hash()))
	require.Equal(t, eps[1], p.Get(eps[1].Hash()))
	require.Nil(t, p.Get(eps[2].Hash()))
	require.Nil(t, p.Get(eps[3].Hash()))
}

// testAdd calls p.Add(ep) and asserts on both of its results.
//
// If expectedErr is non-nil, the returned error must match it according to
// errors.Is; otherwise Add must return no error. In either case the returned
// flag must equal expectedOk. The error is checked before the flag.
func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) {
	// Mark this as a helper so failures are attributed to the calling test line.
	t.Helper()

	ok, err := p.Add(ep)
	if expectedErr != nil {
		require.True(t, errors.Is(err, expectedErr), "got: %v", err)
	} else {
		require.NoError(t, err)
	}
	require.Equal(t, expectedOk, ok)
}

// testChain is a minimal blockchain stub for the pool tests. It embeds
// blockchainer.Blockchainer and defines VerifyWitness, IsExtensibleAllowed
// and BlockHeight itself. newTestChain leaves the embedded interface nil, so
// calling any other Blockchainer method on such a value would panic.
type testChain struct {
	blockchainer.Blockchainer
	// height is the value reported by BlockHeight.
	height        uint32
	// verifyWitness decides whether VerifyWitness succeeds for a given sender.
	verifyWitness func(util.Uint160) bool
	// isAllowed is the predicate behind IsExtensibleAllowed.
	isAllowed     func(util.Uint160) bool
}

// errVerification is the sentinel error returned by testChain.VerifyWitness
// when the configured verifyWitness predicate rejects the sender.
var errVerification = errors.New("verification failed")

// newTestChain builds a testChain with default predicates: witness
// verification fails only for senders whose first byte is 0x42, and senders
// whose first byte is 0x41 or 0x42 are not allowed. The height is left at zero.
func newTestChain() *testChain {
	chain := new(testChain)
	chain.verifyWitness = func(sender util.Uint160) bool {
		return sender[0] != 0x42
	}
	chain.isAllowed = func(sender util.Uint160) bool {
		switch sender[0] {
		case 0x41, 0x42:
			return false
		default:
			return true
		}
	}
	return chain
}
// VerifyWitness consults only the verifyWitness predicate with the given
// script hash; the remaining arguments are ignored. It returns nil when the
// predicate accepts the hash and errVerification otherwise.
func (c *testChain) VerifyWitness(u util.Uint160, _ hash.Hashable, _ *transaction.Witness, _ int64) error {
	if c.verifyWitness(u) {
		return nil
	}
	return errVerification
}
// IsExtensibleAllowed reports whether the sender u is allowed, as decided by
// the isAllowed predicate configured on the test chain.
func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool {
	return c.isAllowed(u)
}
// BlockHeight returns the height field that the test set on the chain.
func (c *testChain) BlockHeight() uint32 { return c.height }