package tree

import (
	"context"
	"errors"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
	grpcService "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool/tree/service"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap/zaptest"
)

// treeClientMock is a test double stored in innerPool.clients (see
// makeInnerPool). The address field is what endpoint returns and what the
// setErrors/setNoErrors helpers match on. When err is true, serviceClient and
// redialIfNecessary return mock errors; every other method always succeeds.
type treeClientMock struct {
	address string
	err     bool
}

func (t *treeClientMock) serviceClient() (grpcService.TreeServiceClient, error) {
	if t.err {
		return nil, errors.New("serviceClient() mock error")
	}
	return nil, nil
}

// endpoint reports the address the mock was created with.
func (t *treeClientMock) endpoint() string { return t.address }

// isHealthy always claims the mock is healthy; failures are simulated through
// the err flag instead.
func (t *treeClientMock) isHealthy() bool {
	const healthy = true
	return healthy
}

// setHealthy is a no-op: the mock keeps no health state, and isHealthy always
// reports true. The redundant bare return has been dropped (staticcheck S1023).
func (t *treeClientMock) setHealthy(bool) {}

func (t *treeClientMock) dial(context.Context) error {
	return nil
}

// redialIfNecessary always reports false for its boolean result. It returns a
// mock error only while the mock is flagged as failing.
func (t *treeClientMock) redialIfNecessary(context.Context) (bool, error) {
	if !t.err {
		return false, nil
	}
	return false, errors.New("redialIfNecessary() mock error")
}

// close has nothing to release and always succeeds.
func (t *treeClientMock) close() error { return nil }

// TestHandleError checks how handleError classifies raw errors:
//   - an unrecognized error stays in the returned error's chain;
//   - a message containing "not found" maps to ErrNodeNotFound;
//   - a message containing "denied" maps to ErrNodeAccessDenied.
//
// Each case has its own name so a failing subtest is identifiable. The check
// uses require.ErrorIs so a failure prints both errors instead of just "false".
func TestHandleError(t *testing.T) {
	defaultError := errors.New("default error")
	for _, tc := range []struct {
		name          string
		err           error
		expectedError error
	}{
		{
			name:          "unrecognized error preserved",
			err:           defaultError,
			expectedError: defaultError,
		},
		{
			name:          "not found",
			err:           errors.New("something not found"),
			expectedError: ErrNodeNotFound,
		},
		{
			name:          "access denied",
			err:           errors.New("something is denied by some acl rule"),
			expectedError: ErrNodeAccessDenied,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := handleError("err message", tc.err)
			require.ErrorIs(t, err, tc.expectedError)
		})
	}
}

// TestRetry drives Pool.requestWithRetry over two groups of mock clients. It
// then checks the start indices reported by getStartIndices. Going by
// makeInnerPool and the expectations below, the pair reads as
// (group, client-within-group), e.g. (0, 1) is node01.
//
// All subtests share one Pool and run in order. Each is expected to leave
// clean state behind (via checkIndicesAndReset / resetClients), so the next
// one starts at (0, 0) with no failing mocks.
func TestRetry(t *testing.T) {
	nodes := [][]string{
		{"node00", "node01", "node02", "node03"},
		{"node10", "node11", "node12", "node13"},
	}

	p := &Pool{
		logger:     zaptest.NewLogger(t),
		innerPools: makeInnerPool(nodes),
	}

	// makeFn always succeeds. Failures in most subtests presumably surface
	// through the mocks' serviceClient() errors, which setErrors toggles.
	makeFn := func(client grpcService.TreeServiceClient) error {
		return nil
	}

	t.Run("first ok", func(t *testing.T) {
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, 0)
	})

	// node00 fails, so the request is expected to land on node01.
	t.Run("first failed", func(t *testing.T) {
		setErrors(p, "node00")
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, 1)
	})

	// Every client fails: an error is returned and the indices stay at (0, 0).
	t.Run("all failed", func(t *testing.T) {
		setErrors(p, nodes[0]...)
		setErrors(p, nodes[1]...)
		err := p.requestWithRetry(makeFn)
		require.Error(t, err)
		checkIndicesAndReset(t, p, 0, 0)
	})

	t.Run("round", func(t *testing.T) {
		setErrors(p, nodes[0][0], nodes[0][1])
		setErrors(p, nodes[1]...)
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndices(t, p, 0, 2)
		// Only mock errors are cleared here, not the indices, so the next
		// request starts from (0, 2).
		resetClientsErrors(p)

		// With node02/node03 failing, the expected result is (0, 0): the search
		// wraps around inside group 0 instead of moving to group 1.
		setErrors(p, nodes[0][2], nodes[0][3])
		err = p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, 0)
	})

	// All of group 0 and node10 fail, so the request should land on node11.
	t.Run("group switch", func(t *testing.T) {
		setErrors(p, nodes[0]...)
		setErrors(p, nodes[1][0])
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 1, 1)
	})

	// node01..node03 fail, but node00 (the start position) is healthy, so the
	// indices are expected to stay at (0, 0).
	t.Run("group round", func(t *testing.T) {
		setErrors(p, nodes[0][1:]...)
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, 0)
	})

	// The request starts mid-group at (0, 1) with all of group 0 failing and
	// is expected to move on to the first client of group 1.
	t.Run("group round switch", func(t *testing.T) {
		setErrors(p, nodes[0]...)
		p.setStartIndices(0, 1)
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 1, 0)
	})

	// The request starts in the last group with all of it failing and is
	// expected to wrap back to (0, 0). NOTE(review): the name suggests this
	// guards against an index-out-of-range panic on the wrap; confirm against
	// requestWithRetry.
	t.Run("no panic group switch", func(t *testing.T) {
		setErrors(p, nodes[1]...)
		p.setStartIndices(1, 0)
		err := p.requestWithRetry(makeFn)
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, 0)
	})

	// The callback itself returns errNodeEmptyResult for the first errNodes
	// attempts. The request is expected to succeed on the next client,
	// (0, errNodes), i.e. this error is retried on another node.
	t.Run("error empty result", func(t *testing.T) {
		errNodes, index := 2, 0
		err := p.requestWithRetry(func(client grpcService.TreeServiceClient) error {
			if index < errNodes {
				index++
				return errNodeEmptyResult
			}
			return nil
		})
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, errNodes)
	})

	// Same as above, but with ErrNodeNotFound: it is also expected to be
	// retried on the following clients.
	t.Run("error not found", func(t *testing.T) {
		errNodes, index := 2, 0
		err := p.requestWithRetry(func(client grpcService.TreeServiceClient) error {
			if index < errNodes {
				index++
				return ErrNodeNotFound
			}
			return nil
		})
		require.NoError(t, err)
		checkIndicesAndReset(t, p, 0, errNodes)
	})

	// ErrNodeAccessDenied is expected to be returned without retrying: the
	// callback runs exactly once and the indices do not move.
	t.Run("error access denied", func(t *testing.T) {
		var index int
		err := p.requestWithRetry(func(client grpcService.TreeServiceClient) error {
			index++
			return ErrNodeAccessDenied
		})
		require.ErrorIs(t, err, ErrNodeAccessDenied)
		require.Equal(t, 1, index)
		checkIndicesAndReset(t, p, 0, 0)
	})
}

// TestRebalance drives Pool.updateNodesHealth over two groups of two mock
// clients and checks the start indices it leaves behind. NOTE(review): the
// mocks' isHealthy always returns true, so health presumably comes from
// redialIfNecessary, which fails while a mock's err flag is set. Confirm this
// against updateNodesHealth.
//
// buffers is allocated once (sized per group by makeBuffer) and reused across
// every call and subtest. Subtests share state, run in order and finish with
// resetClients.
func TestRebalance(t *testing.T) {
	nodes := [][]string{
		{"node00", "node01"},
		{"node10", "node11"},
	}

	p := &Pool{
		logger:     zaptest.NewLogger(t),
		innerPools: makeInnerPool(nodes),
		rebalanceParams: rebalanceParameters{
			nodesGroup: makeNodesGroup(nodes),
		},
	}

	ctx := context.Background()
	buffers := makeBuffer(p)

	// A second pass over the same buffers must still notice that node00 went
	// bad and move the start to node01.
	t.Run("check dirty buffers", func(t *testing.T) {
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 0, 0)
		setErrors(p, nodes[0][0])
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 0, 1)
		resetClients(p)
	})

	t.Run("don't change healthy status", func(t *testing.T) {
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 0, 0)
		resetClients(p)
	})

	// Whole group 0 failing moves the start to the first client of group 1.
	t.Run("switch to second group", func(t *testing.T) {
		setErrors(p, nodes[0][0], nodes[0][1])
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 1, 0)
		resetClients(p)
	})

	// Once in group 1, repeated checks stay there. Recovering node00 alone
	// (node01 is still failing) is expected to bring the start back to (0, 0).
	t.Run("switch back and forth", func(t *testing.T) {
		setErrors(p, nodes[0][0], nodes[0][1])
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 1, 0)

		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 1, 0)

		setNoErrors(p, nodes[0][0])
		p.updateNodesHealth(ctx, buffers)
		checkIndices(t, p, 0, 0)

		resetClients(p)
	})
}

// makeInnerPool builds one innerPool per address group. Each pool's clients
// are fresh treeClientMocks, kept in the same order as the given addresses.
func makeInnerPool(nodes [][]string) []*innerPool {
	pools := make([]*innerPool, 0, len(nodes))

	for _, addrs := range nodes {
		clients := make([]client, 0, len(addrs))
		for _, addr := range addrs {
			clients = append(clients, &treeClientMock{address: addr})
		}
		pools = append(pools, &innerPool{clients: clients})
	}

	return pools
}

// makeNodesGroup mirrors the address groups as pool.NodeParam groups. Every
// entry is created with pool.NewNodeParam(1, addr, 1); the two 1s are
// presumably priority and weight (see the SDK signature).
func makeNodesGroup(nodes [][]string) [][]pool.NodeParam {
	groups := make([][]pool.NodeParam, 0, len(nodes))

	for _, addrs := range nodes {
		params := make([]pool.NodeParam, 0, len(addrs))
		for _, addr := range addrs {
			params = append(params, pool.NewNodeParam(1, addr, 1))
		}
		groups = append(groups, params)
	}

	return groups
}

// makeBuffer allocates a zeroed [][]bool shaped like
// p.rebalanceParams.nodesGroup: one row per group, one cell per node.
func makeBuffer(p *Pool) [][]bool {
	groups := p.rebalanceParams.nodesGroup
	buffers := make([][]bool, len(groups))
	for i := range buffers {
		buffers[i] = make([]bool, len(groups[i]))
	}
	return buffers
}

// checkIndicesAndReset asserts that p's start indices equal (iExp, jExp). It
// then clears every mock's error flag and resets the start indices to (0, 0).
// It is marked as a test helper so a failure is reported at the calling
// subtest's line rather than inside this function.
func checkIndicesAndReset(t *testing.T, p *Pool, iExp, jExp int) {
	t.Helper()
	checkIndices(t, p, iExp, jExp)
	resetClients(p)
}

// checkIndices fails the test unless p.getStartIndices() returns
// (iExp, jExp). The pair is compared as a [2]int so a failure shows both
// values together. It is marked as a test helper so the reported failure line
// is the caller's, not this function's.
func checkIndices(t *testing.T, p *Pool, iExp, jExp int) {
	t.Helper()
	i, j := p.getStartIndices()
	require.Equal(t, [2]int{iExp, jExp}, [2]int{i, j})
}

// resetClients returns the fixture to its initial state. It first clears
// every mock's error flag, then moves the start indices back to (0, 0).
func resetClients(p *Pool) {
	resetClientsErrors(p)
	p.setStartIndices(0, 0)
}

// resetClientsErrors marks every mock client in every inner pool as
// non-failing. It panics if a client is not a *treeClientMock.
func resetClientsErrors(p *Pool) {
	for gi := range p.innerPools {
		for _, c := range p.innerPools[gi].clients {
			c.(*treeClientMock).err = false
		}
	}
}

// setErrors makes the mocks with the listed addresses fail.
func setErrors(p *Pool, nodes ...string) {
	const failing = true
	setErrorsBase(p, failing, nodes...)
}

// setNoErrors makes the mocks with the listed addresses succeed again.
func setNoErrors(p *Pool, nodes ...string) {
	const failing = false
	setErrorsBase(p, failing, nodes...)
}

// setErrorsBase sets the err flag to the given value on every mock whose
// address is in nodes; all other mocks are left untouched. It panics if a
// client is not a *treeClientMock.
func setErrorsBase(p *Pool, err bool, nodes ...string) {
	for _, ip := range p.innerPools {
		for _, c := range ip.clients {
			mock := c.(*treeClientMock)
			if !containsStr(nodes, mock.address) {
				continue
			}
			mock.err = err
		}
	}
}

// containsStr reports whether item occurs in list; a nil or empty list never
// contains anything.
func containsStr(list []string, item string) bool {
	for _, s := range list {
		if s == item {
			return true
		}
	}
	return false
}