From e1408b65255fc62c0f300f6f3bf09acd3f96cdf7 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Mon, 11 May 2020 01:00:19 +0300 Subject: [PATCH] rpc/server: add notification subscription Note that the protocol differs a bit from #895 in its notifications format, to avoid additional server-side processing we're omitting some metadata like: * block size and confirmations * transaction fees, confirmations, block hash and timestamp * application execution doesn't have ScriptHash populated Some block fields may also differ in encoding compared to `getblock` results (like nonce field). I think these differences are unnoticieable for most use cases, so we can leave them as is, but it can be changed in the future. --- pkg/rpc/response/events.go | 79 ++++++ pkg/rpc/response/result/application_log.go | 18 +- pkg/rpc/response/types.go | 9 + pkg/rpc/server/server.go | 316 ++++++++++++++++++++- pkg/rpc/server/server_helper_test.go | 28 +- pkg/rpc/server/subscription.go | 35 +++ pkg/rpc/server/subscription_test.go | 227 +++++++++++++++ 7 files changed, 688 insertions(+), 24 deletions(-) create mode 100644 pkg/rpc/response/events.go create mode 100644 pkg/rpc/server/subscription.go create mode 100644 pkg/rpc/server/subscription_test.go diff --git a/pkg/rpc/response/events.go b/pkg/rpc/response/events.go new file mode 100644 index 000000000..1efba39a5 --- /dev/null +++ b/pkg/rpc/response/events.go @@ -0,0 +1,79 @@ +package response + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +type ( + // EventID represents an event type happening on the chain. + EventID byte +) + +const ( + // InvalidEventID is an invalid event id that is the default value of + // EventID. It's only used as an initial value similar to nil. + InvalidEventID EventID = iota + // BlockEventID is a `block_added` event. + BlockEventID + // TransactionEventID corresponds to `transaction_added` event. + TransactionEventID + // NotificationEventID represents `notification_from_execution` events. + NotificationEventID + // ExecutionEventID is used for `transaction_executed` events. + ExecutionEventID +) + +// String is a good old Stringer implementation. +func (e EventID) String() string { + switch e { + case BlockEventID: + return "block_added" + case TransactionEventID: + return "transaction_added" + case NotificationEventID: + return "notification_from_execution" + case ExecutionEventID: + return "transaction_executed" + default: + return "unknown" + } +} + +// GetEventIDFromString converts input string into an EventID if it's possible. +func GetEventIDFromString(s string) (EventID, error) { + switch s { + case "block_added": + return BlockEventID, nil + case "transaction_added": + return TransactionEventID, nil + case "notification_from_execution": + return NotificationEventID, nil + case "transaction_executed": + return ExecutionEventID, nil + default: + return 255, errors.New("invalid stream name") + } +} + +// MarshalJSON implements json.Marshaler interface. +func (e EventID) MarshalJSON() ([]byte, error) { + return json.Marshal(e.String()) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (e *EventID) UnmarshalJSON(b []byte) error { + var s string + + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + id, err := GetEventIDFromString(s) + if err != nil { + return err + } + *e = id + return nil +} diff --git a/pkg/rpc/response/result/application_log.go b/pkg/rpc/response/result/application_log.go index da436f2f4..ee59fc43d 100644 --- a/pkg/rpc/response/result/application_log.go +++ b/pkg/rpc/response/result/application_log.go @@ -30,16 +30,22 @@ type NotificationEvent struct { Item smartcontract.Parameter `json:"state"` } +// StateEventToResultNotification converts state.NotificationEvent to +// result.NotificationEvent. +func StateEventToResultNotification(event state.NotificationEvent) NotificationEvent { + seen := make(map[vm.StackItem]bool) + item := event.Item.ToContractParameter(seen) + return NotificationEvent{ + Contract: event.ScriptHash, + Item: item, + } +} + // NewApplicationLog creates a new ApplicationLog wrapper. func NewApplicationLog(appExecRes *state.AppExecResult, scriptHash util.Uint160) ApplicationLog { events := make([]NotificationEvent, 0, len(appExecRes.Events)) for _, e := range appExecRes.Events { - seen := make(map[vm.StackItem]bool) - item := e.Item.ToContractParameter(seen) - events = append(events, NotificationEvent{ - Contract: e.ScriptHash, - Item: item, - }) + events = append(events, StateEventToResultNotification(e)) } triggerString := appExecRes.Trigger.String() diff --git a/pkg/rpc/response/types.go b/pkg/rpc/response/types.go index 0b236826a..ba23c7677 100644 --- a/pkg/rpc/response/types.go +++ b/pkg/rpc/response/types.go @@ -37,3 +37,12 @@ type GetRawTx struct { HeaderAndError Result *result.TransactionOutputRaw `json:"result"` } + +// Notification is a type used to represent wire format of events, they're +// special in that they look like requests but they don't have IDs and their +// "method" is actually an event name. +type Notification struct { + JSONRPC string `json:"jsonrpc"` + Event EventID `json:"method"` + Payload []interface{} `json:"params"` +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 5641b9501..2c2a2da41 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "strconv" + "sync" "time" "github.com/gorilla/websocket" @@ -41,6 +42,19 @@ type ( coreServer *network.Server log *zap.Logger https *http.Server + shutdown chan struct{} + + subsLock sync.RWMutex + subscribers map[*subscriber]bool + subsGroup sync.WaitGroup + blockSubs int + executionSubs int + notificationSubs int + transactionSubs int + blockCh chan *block.Block + executionCh chan *state.AppExecResult + notificationCh chan *state.NotificationEvent + transactionCh chan *transaction.Transaction } ) @@ -56,6 +70,11 @@ const ( // Write deadline. wsWriteLimit = wsPingPeriod / 2 + + // Maximum number of subscribers per Server. Each websocket client is + // treated like subscriber, so technically it's a limit on websocket + // connections. + maxSubscribers = 64 ) var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ @@ -91,6 +110,11 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon "validateaddress": (*Server).validateAddress, } +var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){ + "subscribe": (*Server).subscribe, + "unsubscribe": (*Server).unsubscribe, +} + var invalidBlockHeightError = func(index int, height int) *response.Error { return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } @@ -119,6 +143,14 @@ func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, l coreServer: coreServer, log: log, https: tlsServer, + shutdown: make(chan struct{}), + + subscribers: make(map[*subscriber]bool), + // These are NOT buffered to preserve original order of events. + blockCh: make(chan *block.Block), + executionCh: make(chan *state.AppExecResult), + notificationCh: make(chan *state.NotificationEvent), + transactionCh: make(chan *transaction.Transaction), } } @@ -133,6 +165,7 @@ func (s *Server) Start(errChan chan error) { s.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr)) + go s.handleSubEvents() if cfg := s.config.TLSConfig; cfg.Enabled { s.https.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr)) @@ -155,6 +188,10 @@ func (s *Server) Start(errChan chan error) { // method. func (s *Server) Shutdown() error { var httpsErr error + + // Signal to websocket writer routines and handleSubEvents. + close(s.shutdown) + if s.config.TLSConfig.Enabled { s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr)) httpsErr = s.https.Shutdown(context.Background()) @@ -162,6 +199,10 @@ func (s *Server) Shutdown() error { s.log.Info("shutting down rpc-server", zap.String("endpoint", s.Addr)) err := s.Server.Shutdown(context.Background()) + + // Wait for handleSubEvents to finish. + <-s.executionCh + if err == nil { return httpsErr } @@ -169,20 +210,40 @@ func (s *Server) Shutdown() error { } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { + req := request.NewIn() + if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { + // Technically there is a race between this check and + // s.subscribers modification 20 lines below, but it's tiny + // and not really critical to bother with it. Some additional + // clients may sneak in, no big deal. + s.subsLock.RLock() + numOfSubs := len(s.subscribers) + s.subsLock.RUnlock() + if numOfSubs >= maxSubscribers { + s.writeHTTPErrorResponse( + req, + w, + response.NewInternalServerError("websocket users limit reached", nil), + ) + return + } ws, err := upgrader.Upgrade(w, httpRequest, nil) if err != nil { s.log.Info("websocket connection upgrade failed", zap.Error(err)) return } resChan := make(chan response.Raw) - go s.handleWsWrites(ws, resChan) - s.handleWsReads(ws, resChan) + subChan := make(chan *websocket.PreparedMessage, notificationBufSize) + subscr := &subscriber{writer: subChan, ws: ws} + s.subsLock.Lock() + s.subscribers[subscr] = true + s.subsLock.Unlock() + go s.handleWsWrites(ws, resChan, subChan) + s.handleWsReads(ws, resChan, subscr) return } - req := request.NewIn() - if httpRequest.Method != "POST" { s.writeHTTPErrorResponse( req, @@ -200,11 +261,14 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ return } - resp := s.handleRequest(req) + resp := s.handleRequest(req, nil) s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) handleRequest(req *request.In) response.Raw { +func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw { + var res interface{} + var resErr *response.Error + reqParams, err := req.Params() if err != nil { return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err)) @@ -216,20 +280,37 @@ func (s *Server) handleRequest(req *request.In) response.Raw { incCounter(req.Method) + resErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil) handler, ok := rpcHandlers[req.Method] - if !ok { - return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)) + if ok { + res, resErr = handler(s, *reqParams) + } else if sub != nil { + handler, ok := rpcWsHandlers[req.Method] + if ok { + res, resErr = handler(s, *reqParams, sub) + } } - res, resErr := handler(s, *reqParams) return s.packResponseToRaw(req, res, resErr) } -func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) { +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) { pingTicker := time.NewTicker(wsPingPeriod) defer ws.Close() defer pingTicker.Stop() for { select { + case <-s.shutdown: + // Signal to the reader routine. + ws.Close() + return + case event, ok := <-subChan: + if !ok { + return + } + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WritePreparedMessage(event); err != nil { + return + } case res, ok := <-resChan: if !ok { return @@ -247,22 +328,36 @@ func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) } } -func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) { +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) { ws.SetReadLimit(wsReadLimit) ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) +requestloop: for { req := new(request.In) err := ws.ReadJSON(req) if err != nil { break } - res := s.handleRequest(req) + res := s.handleRequest(req, subscr) if res.Error != nil { s.logRequestError(req, res.Error) } - resChan <- res + select { + case <-s.shutdown: + break requestloop + case resChan <- res: + } + } + s.subsLock.Lock() + delete(s.subscribers, subscr) + for _, e := range subscr.feeds { + if e != response.InvalidEventID { + s.unsubscribeFromChannel(e) + } + } + s.subsLock.Unlock() close(resChan) ws.Close() } @@ -1025,6 +1120,201 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res return results, resultsErr } +// subscribe handles subscription requests from websocket clients. +func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { + p, ok := reqParams.Value(0) + if !ok { + return nil, response.ErrInvalidParams + } + streamName, err := p.GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + event, err := response.GetEventIDFromString(streamName) + if err != nil { + return nil, response.ErrInvalidParams + } + s.subsLock.Lock() + defer s.subsLock.Unlock() + select { + case <-s.shutdown: + return nil, response.NewInternalServerError("server is shutting down", nil) + default: + } + var id int + for ; id < len(sub.feeds); id++ { + if sub.feeds[id] == response.InvalidEventID { + break + } + } + if id == len(sub.feeds) { + return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil) + } + sub.feeds[id] = event + s.subscribeToChannel(event) + return strconv.FormatInt(int64(id), 10), nil +} + +// subscribeToChannel subscribes RPC server to appropriate chain events if +// it's not yet subscribed for them. It's supposed to be called with s.subsLock +// taken by the caller. +func (s *Server) subscribeToChannel(event response.EventID) { + switch event { + case response.BlockEventID: + if s.blockSubs == 0 { + s.chain.SubscribeForBlocks(s.blockCh) + } + s.blockSubs++ + case response.TransactionEventID: + if s.transactionSubs == 0 { + s.chain.SubscribeForTransactions(s.transactionCh) + } + s.transactionSubs++ + case response.NotificationEventID: + if s.notificationSubs == 0 { + s.chain.SubscribeForNotifications(s.notificationCh) + } + s.notificationSubs++ + case response.ExecutionEventID: + if s.executionSubs == 0 { + s.chain.SubscribeForExecutions(s.executionCh) + } + s.executionSubs++ + } +} + +// unsubscribe handles unsubscription requests from websocket clients. +func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { + p, ok := reqParams.Value(0) + if !ok { + return nil, response.ErrInvalidParams + } + id, err := p.GetInt() + if err != nil || id < 0 { + return nil, response.ErrInvalidParams + } + s.subsLock.Lock() + defer s.subsLock.Unlock() + if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID { + return nil, response.ErrInvalidParams + } + event := sub.feeds[id] + sub.feeds[id] = response.InvalidEventID + s.unsubscribeFromChannel(event) + return true, nil +} + +// unsubscribeFromChannel unsubscribes RPC server from appropriate chain events +// if there are no other subscribers for it. It's supposed to be called with +// s.subsLock taken by the caller. +func (s *Server) unsubscribeFromChannel(event response.EventID) { + switch event { + case response.BlockEventID: + s.blockSubs-- + if s.blockSubs == 0 { + s.chain.UnsubscribeFromBlocks(s.blockCh) + } + case response.TransactionEventID: + s.transactionSubs-- + if s.transactionSubs == 0 { + s.chain.UnsubscribeFromTransactions(s.transactionCh) + } + case response.NotificationEventID: + s.notificationSubs-- + if s.notificationSubs == 0 { + s.chain.UnsubscribeFromNotifications(s.notificationCh) + } + case response.ExecutionEventID: + s.executionSubs-- + if s.executionSubs == 0 { + s.chain.UnsubscribeFromExecutions(s.executionCh) + } + } +} + +func (s *Server) handleSubEvents() { +chloop: + for { + var resp = response.Notification{ + JSONRPC: request.JSONRPCVersion, + Payload: make([]interface{}, 1), + } + var msg *websocket.PreparedMessage + select { + case <-s.shutdown: + break chloop + case b := <-s.blockCh: + resp.Event = response.BlockEventID + resp.Payload[0] = b + case execution := <-s.executionCh: + resp.Event = response.ExecutionEventID + resp.Payload[0] = result.NewApplicationLog(execution, util.Uint160{}) + case notification := <-s.notificationCh: + resp.Event = response.NotificationEventID + resp.Payload[0] = result.StateEventToResultNotification(*notification) + case tx := <-s.transactionCh: + resp.Event = response.TransactionEventID + resp.Payload[0] = tx + } + s.subsLock.RLock() + subloop: + for sub := range s.subscribers { + for _, subID := range sub.feeds { + if subID == resp.Event { + if msg == nil { + b, err := json.Marshal(resp) + if err != nil { + s.log.Error("failed to marshal notification", + zap.Error(err), + zap.String("type", resp.Event.String())) + break subloop + } + msg, err = websocket.NewPreparedMessage(websocket.TextMessage, b) + if err != nil { + s.log.Error("failed to prepare notification message", + zap.Error(err), + zap.String("type", resp.Event.String())) + break subloop + } + } + sub.writer <- msg + // The message is sent only once per subscriber. + break + } + } + } + s.subsLock.RUnlock() + } + // It's important to do it with lock held because no subscription routine + // should be running concurrently to this one. And even if one is to run + // after unlock, it'll see closed s.shutdown and won't subscribe. + s.subsLock.Lock() + // There might be no subscription in reality, but it's not a problem as + // core.Blockchain allows unsubscribing non-subscribed channels. + s.chain.UnsubscribeFromBlocks(s.blockCh) + s.chain.UnsubscribeFromTransactions(s.transactionCh) + s.chain.UnsubscribeFromNotifications(s.notificationCh) + s.chain.UnsubscribeFromExecutions(s.executionCh) + s.subsLock.Unlock() +drainloop: + for { + select { + case <-s.blockCh: + case <-s.executionCh: + case <-s.notificationCh: + case <-s.transactionCh: + default: + break drainloop + } + } + // It's not required closing these, but since they're drained already + // this is safe and it also allows to give a signal to Shutdown routine. + close(s.blockCh) + close(s.transactionCh) + close(s.notificationCh) + close(s.executionCh) +} + func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { num, err := param.GetInt() if err != nil { diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 7dc8fd263..61bcc2e0d 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -15,12 +15,11 @@ import ( "github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" + "go.uber.org/zap" "go.uber.org/zap/zaptest" ) -func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { - var nBlocks uint32 - +func getUnitTestChain(t *testing.T) (*core.Blockchain, config.Config, *zap.Logger) { net := config.ModeUnitTestNet configPath := "../../../config" cfg, err := config.Load(configPath, net) @@ -33,6 +32,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http go chain.Run() + return chain, cfg, logger +} + +func getTestBlocks(t *testing.T) []*block.Block { + blocks := make([]*block.Block, 0) // File "./testdata/testblocks.acc" was generated by function core._ // ("neo-go/pkg/core/helper_test.go"). // To generate new "./testdata/testblocks.acc", follow the steps: @@ -42,15 +46,20 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http f, err := os.Open("testdata/testblocks.acc") require.Nil(t, err) br := io.NewBinReaderFromIO(f) - nBlocks = br.ReadU32LE() + nBlocks := br.ReadU32LE() require.Nil(t, br.Err) for i := 0; i < int(nBlocks); i++ { _ = br.ReadU32LE() b := &block.Block{} b.DecodeBinary(br) require.Nil(t, br.Err) - require.NoError(t, chain.AddBlock(b)) + blocks = append(blocks, b) } + return blocks +} + +func initClearServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { + chain, cfg, logger := getUnitTestChain(t) serverConfig := network.NewServerConfig(cfg) server, err := network.NewServer(serverConfig, chain, logger) @@ -65,6 +74,15 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http return chain, &rpcServer, srv } +func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { + chain, rpcServer, srv := initClearServerWithInMemoryChain(t) + + for _, b := range getTestBlocks(t) { + require.NoError(t, chain.AddBlock(b)) + } + return chain, rpcServer, srv +} + type FeerStub struct{} func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { diff --git a/pkg/rpc/server/subscription.go b/pkg/rpc/server/subscription.go new file mode 100644 index 000000000..10c9e25ec --- /dev/null +++ b/pkg/rpc/server/subscription.go @@ -0,0 +1,35 @@ +package server + +import ( + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" +) + +type ( + // subscriber is an event subscriber. + subscriber struct { + writer chan<- *websocket.PreparedMessage + ws *websocket.Conn + + // These work like slots as there is not a lot of them (it's + // cheaper doing it this way rather than creating a map), + // pointing to EventID is an obvious overkill at the moment, but + // that's not for long. + feeds [maxFeeds]response.EventID + } +) + +const ( + // Maximum number of subscriptions per one client. + maxFeeds = 16 + + // This sets notification messages buffer depth, it may seem to be quite + // big, but there is a big gap in speed between internal event processing + // and networking communication that is combined with spiky nature of our + // event generation process, which leads to lots of events generated in + // short time and they will put some pressure to this buffer (consider + // ~500 invocation txs in one block with some notifications). At the same + // time this channel is about sending pointers, so it's doesn't cost + // a lot in terms of memory used. + notificationBufSize = 1024 +) diff --git a/pkg/rpc/server/subscription_test.go b/pkg/rpc/server/subscription_test.go new file mode 100644 index 000000000..bd4fcb792 --- /dev/null +++ b/pkg/rpc/server/subscription_test.go @@ -0,0 +1,227 @@ +package server + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool) { + for { + ws.SetReadDeadline(time.Now().Add(time.Second)) + _, body, err := ws.ReadMessage() + if isFinished.Load() { + require.Error(t, err) + break + } + require.NoError(t, err) + msgCh <- body + } +} + +func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *response.Raw { + var resp = new(response.Raw) + + ws.SetWriteDeadline(time.Now().Add(time.Second)) + require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg))) + + body := <-respCh + require.NoError(t, json.Unmarshal(body, resp)) + return resp +} + +func getNotification(t *testing.T, respCh <-chan []byte) *response.Notification { + var resp = new(response.Notification) + body := <-respCh + require.NoError(t, json.Unmarshal(body, resp)) + return resp +} + +func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *websocket.Conn, chan []byte, *atomic.Bool) { + chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t) + + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + ws, _, err := dialer.Dial(url, nil) + require.NoError(t, err) + + // Use buffered channel to read server's messages and then read expected + // responses from it. + respMsgs := make(chan []byte, 16) + finishedFlag := atomic.NewBool(false) + go wsReader(t, ws, respMsgs, finishedFlag) + return chain, rpcSrv, ws, respMsgs, finishedFlag +} + +func TestSubscriptions(t *testing.T) { + var subIDs = make([]string, 0) + var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"} + + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + for _, feed := range subFeeds { + var s string + resp := callWSGetRaw(t, c, fmt.Sprintf(`{ + "jsonrpc": "2.0", + "method": "subscribe", + "params": ["%s"], + "id": 1 +}`, feed), respMsgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &s)) + subIDs = append(subIDs, s) + } + + for _, b := range getTestBlocks(t) { + require.NoError(t, chain.AddBlock(b)) + for _, tx := range b.Transactions { + var mayNotify bool + + if tx.Type == transaction.InvocationType { + resp := getNotification(t, respMsgs) + require.Equal(t, response.ExecutionEventID, resp.Event) + mayNotify = true + } + for { + resp := getNotification(t, respMsgs) + if mayNotify && resp.Event == response.NotificationEventID { + continue + } + require.Equal(t, response.TransactionEventID, resp.Event) + break + } + } + resp := getNotification(t, respMsgs) + require.Equal(t, response.BlockEventID, resp.Event) + } + + for _, id := range subIDs { + var b bool + + resp := callWSGetRaw(t, c, fmt.Sprintf(`{ + "jsonrpc": "2.0", + "method": "unsubscribe", + "params": ["%s"], + "id": 1 +}`, id), respMsgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &b)) + require.Equal(t, true, b) + } + finishedFlag.CAS(false, true) + c.Close() +} + +func TestMaxSubscriptions(t *testing.T) { + var subIDs = make([]string, 0) + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + for i := 0; i < maxFeeds+1; i++ { + var s string + resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`, respMsgs) + if i < maxFeeds { + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &s)) + // Each ID must be unique. + for _, id := range subIDs { + require.NotEqual(t, id, s) + } + subIDs = append(subIDs, s) + } else { + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + } + } + + finishedFlag.CAS(false, true) + c.Close() +} + +func TestBadSubUnsub(t *testing.T) { + var subCases = map[string]string{ + "no params": `{"jsonrpc": "2.0", "method": "subscribe", "params": [], "id": 1}`, + "bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`, + "bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`, + } + var unsubCases = map[string]string{ + "no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`, + "bad id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["vasiliy"], "id": 1}`, + "not subscribed id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["7"], "id": 1}`, + } + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + testF := func(t *testing.T, cases map[string]string) func(t *testing.T) { + return func(t *testing.T) { + for n, s := range cases { + t.Run(n, func(t *testing.T) { + resp := callWSGetRaw(t, c, s, respMsgs) + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + }) + } + } + } + t.Run("subscribe", testF(t, subCases)) + t.Run("unsubscribe", testF(t, unsubCases)) + + finishedFlag.CAS(false, true) + c.Close() +} + +func doSomeWSRequest(t *testing.T, ws *websocket.Conn) { + ws.SetWriteDeadline(time.Now().Add(time.Second)) + // It could be just about anything including invalid request, + // we only care about server handling being active. + require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`))) + ws.SetReadDeadline(time.Now().Add(time.Second)) + _, _, err := ws.ReadMessage() + require.NoError(t, err) +} + +func TestWSClientsLimit(t *testing.T) { + chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + wss := make([]*websocket.Conn, maxSubscribers) + + for i := 0; i < len(wss)+1; i++ { + ws, _, err := dialer.Dial(url, nil) + if i < maxSubscribers { + require.NoError(t, err) + wss[i] = ws + // Check that it's completely ready. + doSomeWSRequest(t, ws) + } else { + require.Error(t, err) + } + } + // Check connections are still alive (it actually is necessary to add + // some use of wss to keep connections alive). + for i := 0; i < len(wss); i++ { + doSomeWSRequest(t, wss[i]) + } +}