package server

import (
	"encoding/json"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/nspcc-dev/neo-go/internal/testchain"
	"github.com/nspcc-dev/neo-go/pkg/core"
	"github.com/nspcc-dev/neo-go/pkg/encoding/address"
	"github.com/nspcc-dev/neo-go/pkg/rpc/response"
	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"
)

// wsReader forwards every message read from ws to msgCh until a read fails
// or the test raises isFinished.
//
// It is started in its own goroutine (see initCleanServerAndWSClient), so it
// reports problems with t.Error/t.Errorf and returns instead of using
// require: require ends in t.FailNow, which must only be called from the
// goroutine running the test function.
func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool) {
	for {
		var body []byte

		// Each read gets one second at most. A failure to arm the deadline
		// is handled exactly like a failed read below.
		err := ws.SetReadDeadline(time.Now().Add(time.Second))
		if err == nil {
			_, body, err = ws.ReadMessage()
		}
		if isFinished.Load() {
			// The test is done with the connection, the only expected
			// outcome at this point is an error.
			if err == nil {
				t.Error("unexpected message received after the test has finished")
			}
			return
		}
		if err != nil {
			t.Errorf("failed to read from websocket: %v", err)
			return
		}
		msgCh <- body
	}
}

// callWSGetRaw writes msg to ws as a text message and decodes the next
// message taken from respCh into a response.Raw. The next message is assumed
// to be the reply to msg; nothing here matches it against the request.
//
// The test fails if the write can't be done within a second or if nothing
// shows up on respCh in time. wsReader gives up after one idle second, so
// without the timeout a missing reply would block the test forever.
func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *response.Raw {
	var resp = new(response.Raw)

	require.NoError(t, ws.SetWriteDeadline(time.Now().Add(time.Second)))
	require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg)))

	select {
	case body := <-respCh:
		require.NoError(t, json.Unmarshal(body, resp))
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for the server's response")
	}
	return resp
}

// getNotification takes the next message from respCh and decodes it into a
// response.Notification.
//
// The test fails if no message arrives in time. wsReader gives up after one
// idle second, so without the timeout a missing event would block the test
// forever.
func getNotification(t *testing.T, respCh <-chan []byte) *response.Notification {
	var resp = new(response.Notification)

	select {
	case body := <-respCh:
		require.NoError(t, json.Unmarshal(body, resp))
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for the server's notification")
	}
	return resp
}

// initCleanServerAndWSClient creates a server via
// initClearServerWithInMemoryChain, connects a websocket client to its /ws
// endpoint and starts a wsReader goroutine for that connection.
//
// It returns the chain, the RPC server, the client connection, the channel
// wsReader delivers incoming messages to and the flag that tells wsReader
// the test is finished.
func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *websocket.Conn, chan []byte, *atomic.Bool) {
	chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)

	// Same address as the HTTP test server, but with the "ws" scheme prefix
	// and the websocket path appended.
	wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
	dialer := websocket.Dialer{HandshakeTimeout: time.Second}
	conn, _, err := dialer.Dial(wsURL, nil)
	require.NoError(t, err)

	var (
		// Server messages are collected in a buffered channel, tests
		// then pick the expected responses from it.
		incoming = make(chan []byte, 16)
		done     = atomic.NewBool(false)
	)
	go wsReader(t, conn, incoming, done)

	return chain, rpcSrv, conn, incoming, done
}

// callSubscribe sends a "subscribe" request with the given params (inserted
// verbatim as the JSON "params" value) and returns the result decoded as a
// string. The test fails if the response carries an error or has no result.
func callSubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, params string) string {
	request := fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, params)

	raw := callWSGetRaw(t, ws, request, msgs)
	require.Nil(t, raw.Error)
	require.NotNil(t, raw.Result)

	var subID string
	require.NoError(t, json.Unmarshal(raw.Result, &subID))
	return subID
}

// callUnsubscribe sends an "unsubscribe" request for the given subscription
// id and requires the response to be an error-free result that decodes to
// boolean true.
func callUnsubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, id string) {
	request := fmt.Sprintf(`{"jsonrpc": "2.0","method": "unsubscribe","params": ["%s"],"id": 1}`, id)

	raw := callWSGetRaw(t, ws, request, msgs)
	require.Nil(t, raw.Error)
	require.NotNil(t, raw.Result)

	var ok bool
	require.NoError(t, json.Unmarshal(raw.Result, &ok))
	require.Equal(t, true, ok)
}

// TestSubscriptions subscribes to all four unfiltered feeds, adds the test
// blocks one by one and checks the order of events received for each of
// them, then unsubscribes from every feed.
func TestSubscriptions(t *testing.T) {
	var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"}
	var subIDs = make([]string, 0, len(subFeeds))

	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()
	// Deferred so that the reader goroutine is told to stop and the socket
	// is closed even when a failed check aborts the test early. Being
	// registered last, it runs before the server is shut down.
	defer func() {
		finishedFlag.CAS(false, true)
		c.Close()
	}()

	for _, feed := range subFeeds {
		s := callSubscribe(t, c, respMsgs, fmt.Sprintf(`["%s"]`, feed))
		subIDs = append(subIDs, s)
	}

	for _, b := range getTestBlocks(t) {
		require.NoError(t, chain.AddBlock(b))

		// Every block starts with an execution event followed by any
		// number of notification events.
		resp := getNotification(t, respMsgs)
		require.Equal(t, response.ExecutionEventID, resp.Event)
		for {
			resp = getNotification(t, respMsgs)
			if resp.Event != response.NotificationEventID {
				break
			}
		}

		// Then for each transaction: an execution event, its notification
		// events and the transaction event itself. The execution event of
		// the first transaction is already in resp, it's the one that
		// ended the loop above.
		for i := 0; i < len(b.Transactions); i++ {
			if i > 0 {
				resp = getNotification(t, respMsgs)
			}
			require.Equal(t, response.ExecutionEventID, resp.Event)
			for {
				txResp := getNotification(t, respMsgs)
				if txResp.Event == response.NotificationEventID {
					continue
				}
				require.Equal(t, response.TransactionEventID, txResp.Event)
				break
			}
		}

		// One more execution event with its notifications and then the
		// block event closes the sequence.
		resp = getNotification(t, respMsgs)
		require.Equal(t, response.ExecutionEventID, resp.Event)
		for {
			resp = getNotification(t, respMsgs)
			if resp.Event != response.NotificationEventID {
				break
			}
		}
		require.Equal(t, response.BlockEventID, resp.Event)
	}

	for _, id := range subIDs {
		callUnsubscribe(t, c, respMsgs, id)
	}
}

// TestFilteredSubscriptions runs one subtest per filter. Each subtest starts
// a fresh server, subscribes to unfiltered "block_added" plus the filtered
// feed under test, adds all test blocks and passes every non-block event it
// receives to the case's check function. For the non-matching cases any such
// event is a failure.
func TestFilteredSubscriptions(t *testing.T) {
	priv0 := testchain.PrivateKeyByID(0)
	var goodSender = priv0.GetScriptHash()

	var cases = map[string]struct {
		params string
		check  func(*testing.T, *response.Notification)
	}{
		"tx matching sender": {
			params: `["transaction_added", {"sender":"` + goodSender.StringLE() + `"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.TransactionEventID, resp.Event)
				sender := rmap["sender"].(string)
				require.Equal(t, address.Uint160ToString(goodSender), sender)
			},
		},
		"tx matching signer": {
			params: `["transaction_added", {"signer":"` + goodSender.StringLE() + `"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.TransactionEventID, resp.Event)
				signers := rmap["signers"].([]interface{})
				signer0 := signers[0].(map[string]interface{})
				signer0acc := signer0["account"].(string)
				require.Equal(t, "0x"+goodSender.StringLE(), signer0acc)
			},
		},
		"tx matching sender and signer": {
			params: `["transaction_added", {"sender":"` + goodSender.StringLE() + `", "signer":"` + goodSender.StringLE() + `"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.TransactionEventID, resp.Event)
				sender := rmap["sender"].(string)
				require.Equal(t, address.Uint160ToString(goodSender), sender)
				signers := rmap["signers"].([]interface{})
				signer0 := signers[0].(map[string]interface{})
				signer0acc := signer0["account"].(string)
				require.Equal(t, "0x"+goodSender.StringLE(), signer0acc)
			},
		},
		"notification matching contract hash": {
			params: `["notification_from_execution", {"contract":"` + testContractHash + `"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.NotificationEventID, resp.Event)
				c := rmap["contract"].(string)
				require.Equal(t, "0x"+testContractHash, c)
			},
		},
		"notification matching name": {
			params: `["notification_from_execution", {"name":"my_pretty_notification"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.NotificationEventID, resp.Event)
				n := rmap["name"].(string)
				require.Equal(t, "my_pretty_notification", n)
			},
		},
		"notification matching contract hash and name": {
			params: `["notification_from_execution", {"contract":"` + testContractHash + `", "name":"my_pretty_notification"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.NotificationEventID, resp.Event)
				c := rmap["contract"].(string)
				require.Equal(t, "0x"+testContractHash, c)
				n := rmap["name"].(string)
				require.Equal(t, "my_pretty_notification", n)
			},
		},
		"execution matching": {
			params: `["transaction_executed", {"state":"HALT"}]`,
			check: func(t *testing.T, resp *response.Notification) {
				rmap := resp.Payload[0].(map[string]interface{})
				require.Equal(t, response.ExecutionEventID, resp.Event)
				st := rmap["vmstate"].(string)
				require.Equal(t, "HALT", st)
			},
		},
		"tx non-matching": {
			params: `["transaction_added", {"sender":"00112233445566778899aabbccddeeff00112233"}]`,
			check: func(t *testing.T, _ *response.Notification) {
				t.Fatal("unexpected match for EnrollmentTransaction")
			},
		},
		"notification non-matching": {
			params: `["notification_from_execution", {"contract":"00112233445566778899aabbccddeeff00112233"}]`,
			check: func(t *testing.T, _ *response.Notification) {
				t.Fatal("unexpected match for contract 00112233445566778899aabbccddeeff00112233")
			},
		},
		"execution non-matching": {
			params: `["transaction_executed", {"state":"FAULT"}]`,
			check: func(t *testing.T, _ *response.Notification) {
				t.Fatal("unexpected match for faulted execution")
			},
		},
	}

	for name, this := range cases {
		t.Run(name, func(t *testing.T) {
			chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

			defer chain.Close()
			defer rpcSrv.Shutdown()
			// Deferred so that the reader goroutine is told to stop and
			// the socket is closed even when a check aborts the subtest
			// (the non-matching cases do exactly that via t.Fatal). Being
			// registered last, it runs before the server is shut down.
			defer func() {
				finishedFlag.CAS(false, true)
				c.Close()
			}()

			// It's used as an end-of-event-stream, so it's always present.
			blockSubID := callSubscribe(t, c, respMsgs, `["block_added"]`)
			subID := callSubscribe(t, c, respMsgs, this.params)

			var lastBlock uint32
			for _, b := range getTestBlocks(t) {
				require.NoError(t, chain.AddBlock(b))
				lastBlock = b.Index
			}

			// Drain events until the block event for the last added block,
			// everything that is not a block event goes to the check.
			for {
				resp := getNotification(t, respMsgs)
				rmap := resp.Payload[0].(map[string]interface{})
				if resp.Event == response.BlockEventID {
					index := rmap["index"].(float64)
					if uint32(index) == lastBlock {
						break
					}
					continue
				}
				this.check(t, resp)
			}

			callUnsubscribe(t, c, respMsgs, subID)
			callUnsubscribe(t, c, respMsgs, blockSubID)
		})
	}
}

// TestFilteredBlockSubscriptions subscribes to "block_added" filtered by
// primary 3, adds numBlocks blocks with the primary cycling through 0..3 and
// expects one block event with primary 3 for each matching block.
func TestFilteredBlockSubscriptions(t *testing.T) {
	// We can't fit this into TestFilteredSubscriptions, because it uses
	// blocks as EOF events to wait for.
	const numBlocks = 10
	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()
	// Deferred so that the reader goroutine is told to stop and the socket
	// is closed even when a failed check or a timeout aborts the test early.
	// Being registered last, it runs before the server is shut down.
	defer func() {
		finishedFlag.CAS(false, true)
		c.Close()
	}()

	blockSubID := callSubscribe(t, c, respMsgs, `["block_added", {"primary":3}]`)

	var expectedCnt int
	for i := 0; i < numBlocks; i++ {
		primary := uint32(i % 4)
		if primary == 3 {
			expectedCnt++
		}
		b := testchain.NewBlock(t, chain, 1, primary)
		require.NoError(t, chain.AddBlock(b))
	}

	for i := 0; i < expectedCnt; i++ {
		var resp = new(response.Notification)
		select {
		case body := <-respMsgs:
			require.NoError(t, json.Unmarshal(body, resp))
		case <-time.After(time.Second):
			t.Fatal("timeout waiting for event")
		}

		require.Equal(t, response.BlockEventID, resp.Event)
		rmap := resp.Payload[0].(map[string]interface{})
		cd := rmap["consensusdata"].(map[string]interface{})
		// JSON numbers are decoded as float64 when the target is interface{}.
		primary := cd["primary"].(float64)
		require.Equal(t, 3, int(primary))
	}
	callUnsubscribe(t, c, respMsgs, blockSubID)
}

func TestMaxSubscriptions(t *testing.T) {
	var subIDs = make([]string, 0)
	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()

	for i := 0; i < maxFeeds+1; i++ {
		var s string
		resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`, respMsgs)
		if i < maxFeeds {
			require.Nil(t, resp.Error)
			require.NotNil(t, resp.Result)
			require.NoError(t, json.Unmarshal(resp.Result, &s))
			// Each ID must be unique.
			for _, id := range subIDs {
				require.NotEqual(t, id, s)
			}
			subIDs = append(subIDs, s)
		} else {
			require.NotNil(t, resp.Error)
			require.Nil(t, resp.Result)
		}
	}

	finishedFlag.CAS(false, true)
	c.Close()
}

// TestBadSubUnsub sends malformed or otherwise unacceptable "subscribe" and
// "unsubscribe" requests over one connection and requires each of them to be
// answered with an error and no result.
func TestBadSubUnsub(t *testing.T) {
	var subCases = map[string]string{
		"no params":              `{"jsonrpc": "2.0", "method": "subscribe", "params": [], "id": 1}`,
		"bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`,
		"bad (wrong) event":      `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`,
		"missed event":           `{"jsonrpc": "2.0", "method": "subscribe", "params": ["event_missed"], "id": 1}`,
		"block invalid filter":   `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added", 1], "id": 1}`,
		"tx filter 1":            `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", 1], "id": 1}`,
		"tx filter 2":            `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", {"state": "HALT"}], "id": 1}`,
		"notification filter 1":  `{"jsonrpc": "2.0", "method": "subscribe", "params": ["notification_from_execution", "contract"], "id": 1}`,
		"notification filter 2":  `{"jsonrpc": "2.0", "method": "subscribe", "params": ["notification_from_execution", "name"], "id": 1}`,
		"execution filter 1":     `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", "FAULT"], "id": 1}`,
		"execution filter 2":     `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", {"state": "STOP"}], "id": 1}`,
	}
	var unsubCases = map[string]string{
		"no params":         `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`,
		"bad id":            `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["vasiliy"], "id": 1}`,
		"not subscribed id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["7"], "id": 1}`,
	}
	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()

	// expectErrors builds a test function that runs every request of the
	// given set as a separate subtest over the shared connection.
	expectErrors := func(requests map[string]string) func(t *testing.T) {
		return func(t *testing.T) {
			for name, request := range requests {
				t.Run(name, func(t *testing.T) {
					raw := callWSGetRaw(t, c, request, respMsgs)
					require.NotNil(t, raw.Error)
					require.Nil(t, raw.Result)
				})
			}
		}
	}
	t.Run("subscribe", expectErrors(subCases))
	t.Run("unsubscribe", expectErrors(unsubCases))

	finishedFlag.CAS(false, true)
	c.Close()
}

// doSomeWSRequest sends a "getversion" request over ws and waits for one
// message in return, giving each direction one second. The reply's contents
// are not inspected; the test fails if a deadline can't be set or if the
// write or the read fails.
func doSomeWSRequest(t *testing.T, ws *websocket.Conn) {
	require.NoError(t, ws.SetWriteDeadline(time.Now().Add(time.Second)))
	// It could be just about anything including invalid request,
	// we only care about server handling being active.
	require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`)))
	require.NoError(t, ws.SetReadDeadline(time.Now().Add(time.Second)))
	_, _, err := ws.ReadMessage()
	require.NoError(t, err)
}

// TestWSClientsLimit opens maxSubscribers websocket connections, checks that
// each of them is served, requires one more connection attempt to fail and
// then checks that all the established connections still work.
func TestWSClientsLimit(t *testing.T) {
	chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)
	defer chain.Close()
	defer rpcSrv.Shutdown()

	dialer := websocket.Dialer{HandshakeTimeout: time.Second}
	url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
	wss := make([]*websocket.Conn, maxSubscribers)

	// Close whatever was established when the test ends (successfully or
	// not) so that client sockets don't leak. Being registered last, this
	// runs before the server is shut down.
	defer func() {
		for _, ws := range wss {
			if ws != nil {
				ws.Close()
			}
		}
	}()

	for i := 0; i < len(wss)+1; i++ {
		ws, _, err := dialer.Dial(url, nil)
		if i < maxSubscribers {
			require.NoError(t, err)
			wss[i] = ws
			// Check that it's completely ready.
			doSomeWSRequest(t, ws)
		} else {
			require.Error(t, err)
		}
	}
	// Check connections are still alive (it actually is necessary to add
	// some use of wss to keep connections alive).
	for i := 0; i < len(wss); i++ {
		doSomeWSRequest(t, wss[i])
	}
}

// testSubscriptionOverflow tries to overflow buffers on the server side to
// receive a 'missed' event. But it's actually hard to tell when exactly
// that's going to happen because of network-level buffering, typical
// number seen in tests is around ~3500 events, but it's not reliable enough,
// thus this test is disabled (its name doesn't start with "Test", so the
// test runner doesn't pick it up).
func testSubscriptionOverflow(t *testing.T) {
	const blockCnt = notificationBufSize * 5
	var receivedMiss bool

	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()
	// Deferred so that the reader goroutine is told to stop and the socket
	// is closed even when a failed check aborts the test early. Being
	// registered last, it runs before the server is shut down.
	defer func() {
		finishedFlag.CAS(false, true)
		c.Close()
	}()

	resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0","method": "subscribe","params": ["block_added"],"id": 1}`, respMsgs)
	require.Nil(t, resp.Error)
	require.NotNil(t, resp.Result)

	// Push a lot of new blocks, but don't read events for them.
	for i := 0; i < blockCnt; i++ {
		b := testchain.NewBlock(t, chain, 1, 0)
		require.NoError(t, chain.AddBlock(b))
	}
	// Now read them: anything that is not a block event has to be the
	// 'missed' one.
	for i := 0; i < blockCnt; i++ {
		resp := getNotification(t, respMsgs)
		if resp.Event != response.BlockEventID {
			require.Equal(t, response.MissedEventID, resp.Event)
			receivedMiss = true
			break
		}
	}
	require.Equal(t, true, receivedMiss)
	// `Missed` is the last event and there is nothing afterwards.
	require.Equal(t, 0, len(respMsgs))
}