rpc/server: add notification subscription
Note that the protocol differs a bit from #895 in its notifications format; to avoid additional server-side processing we're omitting some metadata:

* block size and confirmations
* transaction fees, confirmations, block hash and timestamp
* application execution doesn't have ScriptHash populated

Some block fields may also differ in encoding compared to `getblock` results (like the nonce field). I think these differences are unnoticeable for most use cases, so we can leave them as is, but it can be changed in the future.
Parent: 29ada4ca46
Commit: e1408b6525

7 changed files with 688 additions and 24 deletions
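For illustration, a typical exchange over the new /ws endpoint looks like this (a sketch based on the code below; --> is client-to-server, <-- is server-to-client, the subscription ID is returned as a string and the block payload is abbreviated):

--> {"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}
<-- {"jsonrpc": "2.0", "result": "0", "id": 1}
<-- {"jsonrpc": "2.0", "method": "block_added", "params": [{...block...}]}
--> {"jsonrpc": "2.0", "method": "unsubscribe", "params": ["0"], "id": 2}
<-- {"jsonrpc": "2.0", "result": true, "id": 2}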
pkg/rpc/response/events.go (new file, 79 lines)
@@ -0,0 +1,79 @@
package response

import (
	"encoding/json"

	"github.com/pkg/errors"
)

type (
	// EventID represents an event type happening on the chain.
	EventID byte
)

const (
	// InvalidEventID is an invalid event id that is the default value of
	// EventID. It's only used as an initial value similar to nil.
	InvalidEventID EventID = iota
	// BlockEventID is a `block_added` event.
	BlockEventID
	// TransactionEventID corresponds to `transaction_added` event.
	TransactionEventID
	// NotificationEventID represents `notification_from_execution` events.
	NotificationEventID
	// ExecutionEventID is used for `transaction_executed` events.
	ExecutionEventID
)

// String is a good old Stringer implementation.
func (e EventID) String() string {
	switch e {
	case BlockEventID:
		return "block_added"
	case TransactionEventID:
		return "transaction_added"
	case NotificationEventID:
		return "notification_from_execution"
	case ExecutionEventID:
		return "transaction_executed"
	default:
		return "unknown"
	}
}

// GetEventIDFromString converts input string into an EventID if it's possible.
func GetEventIDFromString(s string) (EventID, error) {
	switch s {
	case "block_added":
		return BlockEventID, nil
	case "transaction_added":
		return TransactionEventID, nil
	case "notification_from_execution":
		return NotificationEventID, nil
	case "transaction_executed":
		return ExecutionEventID, nil
	default:
		return 255, errors.New("invalid stream name")
	}
}

// MarshalJSON implements json.Marshaler interface.
func (e EventID) MarshalJSON() ([]byte, error) {
	return json.Marshal(e.String())
}

// UnmarshalJSON implements json.Unmarshaler interface.
func (e *EventID) UnmarshalJSON(b []byte) error {
	var s string

	err := json.Unmarshal(b, &s)
	if err != nil {
		return err
	}
	id, err := GetEventIDFromString(s)
	if err != nil {
		return err
	}
	*e = id
	return nil
}
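As a quick check of the marshalers above, a minimal sketch (not part of the commit) of the EventID JSON round-trip:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/rpc/response"
)

func main() {
	// EventID serializes to its string name.
	b, _ := json.Marshal(response.BlockEventID)
	fmt.Println(string(b)) // "block_added"

	// And parses back from it.
	var e response.EventID
	if err := json.Unmarshal([]byte(`"transaction_executed"`), &e); err != nil {
		panic(err)
	}
	fmt.Println(e == response.ExecutionEventID) // true
}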
@@ -30,16 +30,22 @@ type NotificationEvent struct {
 	Item smartcontract.Parameter `json:"state"`
 }
 
+// StateEventToResultNotification converts state.NotificationEvent to
+// result.NotificationEvent.
+func StateEventToResultNotification(event state.NotificationEvent) NotificationEvent {
+	seen := make(map[vm.StackItem]bool)
+	item := event.Item.ToContractParameter(seen)
+	return NotificationEvent{
+		Contract: event.ScriptHash,
+		Item:     item,
+	}
+}
+
 // NewApplicationLog creates a new ApplicationLog wrapper.
 func NewApplicationLog(appExecRes *state.AppExecResult, scriptHash util.Uint160) ApplicationLog {
 	events := make([]NotificationEvent, 0, len(appExecRes.Events))
 	for _, e := range appExecRes.Events {
-		seen := make(map[vm.StackItem]bool)
-		item := e.Item.ToContractParameter(seen)
-		events = append(events, NotificationEvent{
-			Contract: e.ScriptHash,
-			Item:     item,
-		})
+		events = append(events, StateEventToResultNotification(e))
 	}
 
 	triggerString := appExecRes.Trigger.String()
@@ -37,3 +37,12 @@ type GetRawTx struct {
 	HeaderAndError
 	Result *result.TransactionOutputRaw `json:"result"`
 }
+
+// Notification is a type used to represent wire format of events, they're
+// special in that they look like requests but they don't have IDs and their
+// "method" is actually an event name.
+type Notification struct {
+	JSONRPC string        `json:"jsonrpc"`
+	Event   EventID       `json:"method"`
+	Payload []interface{} `json:"params"`
+}
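Marshaling this type produces the wire format described in the commit message; a minimal sketch (not from the commit, the payload contents here are illustrative):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/nspcc-dev/neo-go/pkg/rpc/response"
)

func main() {
	n := response.Notification{
		JSONRPC: "2.0",
		Event:   response.BlockEventID,
		// Real payloads carry a block/transaction/notification/execution
		// object; a placeholder map is used here.
		Payload: []interface{}{map[string]interface{}{"index": 12345}},
	}
	b, _ := json.Marshal(n)
	fmt.Println(string(b))
	// {"jsonrpc":"2.0","method":"block_added","params":[{"index":12345}]}
}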
@@ -9,6 +9,7 @@ import (
 	"net"
 	"net/http"
 	"strconv"
+	"sync"
 	"time"
 
 	"github.com/gorilla/websocket"
@@ -41,6 +42,19 @@ type (
 		coreServer *network.Server
 		log        *zap.Logger
 		https      *http.Server
+		shutdown   chan struct{}
+
+		subsLock         sync.RWMutex
+		subscribers      map[*subscriber]bool
+		subsGroup        sync.WaitGroup
+		blockSubs        int
+		executionSubs    int
+		notificationSubs int
+		transactionSubs  int
+		blockCh          chan *block.Block
+		executionCh      chan *state.AppExecResult
+		notificationCh   chan *state.NotificationEvent
+		transactionCh    chan *transaction.Transaction
 	}
 )
@@ -56,6 +70,11 @@ const (
 
 	// Write deadline.
 	wsWriteLimit = wsPingPeriod / 2
+
+	// Maximum number of subscribers per Server. Each websocket client is
+	// treated like subscriber, so technically it's a limit on websocket
+	// connections.
+	maxSubscribers = 64
 )
 
 var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){
@@ -91,6 +110,11 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
 	"validateaddress": (*Server).validateAddress,
 }
 
+var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){
+	"subscribe":   (*Server).subscribe,
+	"unsubscribe": (*Server).unsubscribe,
+}
+
 var invalidBlockHeightError = func(index int, height int) *response.Error {
 	return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil)
 }
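Note that these two methods are only reachable over a websocket connection: handleRequest (below) consults rpcWsHandlers only when it's given a non-nil subscriber, and plain HTTP requests pass nil there.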
@@ -119,6 +143,14 @@ func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, l
 		coreServer: coreServer,
 		log:        log,
 		https:      tlsServer,
+		shutdown:   make(chan struct{}),
+
+		subscribers: make(map[*subscriber]bool),
+		// These are NOT buffered to preserve original order of events.
+		blockCh:        make(chan *block.Block),
+		executionCh:    make(chan *state.AppExecResult),
+		notificationCh: make(chan *state.NotificationEvent),
+		transactionCh:  make(chan *transaction.Transaction),
 	}
 }
@@ -133,6 +165,7 @@ func (s *Server) Start(errChan chan error) {
 	s.Handler = http.HandlerFunc(s.handleHTTPRequest)
 	s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr))
 
+	go s.handleSubEvents()
 	if cfg := s.config.TLSConfig; cfg.Enabled {
 		s.https.Handler = http.HandlerFunc(s.handleHTTPRequest)
 		s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
||||||
// method.
|
// method.
|
||||||
func (s *Server) Shutdown() error {
|
func (s *Server) Shutdown() error {
|
||||||
var httpsErr error
|
var httpsErr error
|
||||||
|
|
||||||
|
// Signal to websocket writer routines and handleSubEvents.
|
||||||
|
close(s.shutdown)
|
||||||
|
|
||||||
if s.config.TLSConfig.Enabled {
|
if s.config.TLSConfig.Enabled {
|
||||||
s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
||||||
httpsErr = s.https.Shutdown(context.Background())
|
httpsErr = s.https.Shutdown(context.Background())
|
||||||
|
@@ -162,6 +199,10 @@ func (s *Server) Shutdown() error {
 
 	s.log.Info("shutting down rpc-server", zap.String("endpoint", s.Addr))
 	err := s.Server.Shutdown(context.Background())
+
+	// Wait for handleSubEvents to finish.
+	<-s.executionCh
+
 	if err == nil {
 		return httpsErr
 	}
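This receive works as a completion handshake: handleSubEvents (see below) drains and then closes all four event channels, with executionCh closed last, so the receive unblocks only after the subscription machinery has fully stopped.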
@@ -169,20 +210,40 @@ func (s *Server) Shutdown() error {
 }
 
 func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
+	req := request.NewIn()
+
 	if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
+		// Technically there is a race between this check and
+		// s.subscribers modification 20 lines below, but it's tiny
+		// and not really critical to bother with it. Some additional
+		// clients may sneak in, no big deal.
+		s.subsLock.RLock()
+		numOfSubs := len(s.subscribers)
+		s.subsLock.RUnlock()
+		if numOfSubs >= maxSubscribers {
+			s.writeHTTPErrorResponse(
+				req,
+				w,
+				response.NewInternalServerError("websocket users limit reached", nil),
+			)
+			return
+		}
 		ws, err := upgrader.Upgrade(w, httpRequest, nil)
 		if err != nil {
 			s.log.Info("websocket connection upgrade failed", zap.Error(err))
 			return
 		}
 		resChan := make(chan response.Raw)
-		go s.handleWsWrites(ws, resChan)
-		s.handleWsReads(ws, resChan)
+		subChan := make(chan *websocket.PreparedMessage, notificationBufSize)
+		subscr := &subscriber{writer: subChan, ws: ws}
+		s.subsLock.Lock()
+		s.subscribers[subscr] = true
+		s.subsLock.Unlock()
+		go s.handleWsWrites(ws, resChan, subChan)
+		s.handleWsReads(ws, resChan, subscr)
 		return
 	}
 
-	req := request.NewIn()
-
 	if httpRequest.Method != "POST" {
 		s.writeHTTPErrorResponse(
 			req,
@@ -200,11 +261,14 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
 		return
 	}
 
-	resp := s.handleRequest(req)
+	resp := s.handleRequest(req, nil)
 	s.writeHTTPServerResponse(req, w, resp)
 }
 
-func (s *Server) handleRequest(req *request.In) response.Raw {
+func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw {
+	var res interface{}
+	var resErr *response.Error
+
 	reqParams, err := req.Params()
 	if err != nil {
 		return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err))
@@ -216,20 +280,37 @@ func (s *Server) handleRequest(req *request.In) response.Raw {
 
 	incCounter(req.Method)
 
+	resErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)
 	handler, ok := rpcHandlers[req.Method]
-	if !ok {
-		return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil))
+	if ok {
+		res, resErr = handler(s, *reqParams)
+	} else if sub != nil {
+		handler, ok := rpcWsHandlers[req.Method]
+		if ok {
+			res, resErr = handler(s, *reqParams, sub)
+		}
 	}
-	res, resErr := handler(s, *reqParams)
 	return s.packResponseToRaw(req, res, resErr)
 }
 
-func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) {
+func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) {
 	pingTicker := time.NewTicker(wsPingPeriod)
 	defer ws.Close()
 	defer pingTicker.Stop()
 	for {
 		select {
+		case <-s.shutdown:
+			// Signal to the reader routine.
+			ws.Close()
+			return
+		case event, ok := <-subChan:
+			if !ok {
+				return
+			}
+			ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
+			if err := ws.WritePreparedMessage(event); err != nil {
+				return
+			}
 		case res, ok := <-resChan:
 			if !ok {
 				return
@@ -247,22 +328,36 @@ func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw)
 	}
 }
 
-func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) {
+func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) {
 	ws.SetReadLimit(wsReadLimit)
 	ws.SetReadDeadline(time.Now().Add(wsPongLimit))
 	ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
+requestloop:
 	for {
 		req := new(request.In)
 		err := ws.ReadJSON(req)
 		if err != nil {
 			break
 		}
-		res := s.handleRequest(req)
+		res := s.handleRequest(req, subscr)
 		if res.Error != nil {
 			s.logRequestError(req, res.Error)
 		}
-		resChan <- res
+		select {
+		case <-s.shutdown:
+			break requestloop
+		case resChan <- res:
+		}
+
 	}
+	s.subsLock.Lock()
+	delete(s.subscribers, subscr)
+	for _, e := range subscr.feeds {
+		if e != response.InvalidEventID {
+			s.unsubscribeFromChannel(e)
+		}
+	}
+	s.subsLock.Unlock()
 	close(resChan)
 	ws.Close()
 }
@@ -1025,6 +1120,201 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
 	return results, resultsErr
 }
 
+// subscribe handles subscription requests from websocket clients.
+func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
+	p, ok := reqParams.Value(0)
+	if !ok {
+		return nil, response.ErrInvalidParams
+	}
+	streamName, err := p.GetString()
+	if err != nil {
+		return nil, response.ErrInvalidParams
+	}
+	event, err := response.GetEventIDFromString(streamName)
+	if err != nil {
+		return nil, response.ErrInvalidParams
+	}
+	s.subsLock.Lock()
+	defer s.subsLock.Unlock()
+	select {
+	case <-s.shutdown:
+		return nil, response.NewInternalServerError("server is shutting down", nil)
+	default:
+	}
+	var id int
+	for ; id < len(sub.feeds); id++ {
+		if sub.feeds[id] == response.InvalidEventID {
+			break
+		}
+	}
+	if id == len(sub.feeds) {
+		return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil)
+	}
+	sub.feeds[id] = event
+	s.subscribeToChannel(event)
+	return strconv.FormatInt(int64(id), 10), nil
+}
+
+// subscribeToChannel subscribes RPC server to appropriate chain events if
+// it's not yet subscribed for them. It's supposed to be called with s.subsLock
+// taken by the caller.
+func (s *Server) subscribeToChannel(event response.EventID) {
+	switch event {
+	case response.BlockEventID:
+		if s.blockSubs == 0 {
+			s.chain.SubscribeForBlocks(s.blockCh)
+		}
+		s.blockSubs++
+	case response.TransactionEventID:
+		if s.transactionSubs == 0 {
+			s.chain.SubscribeForTransactions(s.transactionCh)
+		}
+		s.transactionSubs++
+	case response.NotificationEventID:
+		if s.notificationSubs == 0 {
+			s.chain.SubscribeForNotifications(s.notificationCh)
+		}
+		s.notificationSubs++
+	case response.ExecutionEventID:
+		if s.executionSubs == 0 {
+			s.chain.SubscribeForExecutions(s.executionCh)
+		}
+		s.executionSubs++
+	}
+}
+
+// unsubscribe handles unsubscription requests from websocket clients.
+func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
+	p, ok := reqParams.Value(0)
+	if !ok {
+		return nil, response.ErrInvalidParams
+	}
+	id, err := p.GetInt()
+	if err != nil || id < 0 {
+		return nil, response.ErrInvalidParams
+	}
+	s.subsLock.Lock()
+	defer s.subsLock.Unlock()
+	if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID {
+		return nil, response.ErrInvalidParams
+	}
+	event := sub.feeds[id]
+	sub.feeds[id] = response.InvalidEventID
+	s.unsubscribeFromChannel(event)
+	return true, nil
+}
+
+// unsubscribeFromChannel unsubscribes RPC server from appropriate chain events
+// if there are no other subscribers for it. It's supposed to be called with
+// s.subsLock taken by the caller.
+func (s *Server) unsubscribeFromChannel(event response.EventID) {
+	switch event {
+	case response.BlockEventID:
+		s.blockSubs--
+		if s.blockSubs == 0 {
+			s.chain.UnsubscribeFromBlocks(s.blockCh)
+		}
+	case response.TransactionEventID:
+		s.transactionSubs--
+		if s.transactionSubs == 0 {
+			s.chain.UnsubscribeFromTransactions(s.transactionCh)
+		}
+	case response.NotificationEventID:
+		s.notificationSubs--
+		if s.notificationSubs == 0 {
+			s.chain.UnsubscribeFromNotifications(s.notificationCh)
+		}
+	case response.ExecutionEventID:
+		s.executionSubs--
+		if s.executionSubs == 0 {
+			s.chain.UnsubscribeFromExecutions(s.executionCh)
+		}
+	}
+}
+
+func (s *Server) handleSubEvents() {
+chloop:
+	for {
+		var resp = response.Notification{
+			JSONRPC: request.JSONRPCVersion,
+			Payload: make([]interface{}, 1),
+		}
+		var msg *websocket.PreparedMessage
+		select {
+		case <-s.shutdown:
+			break chloop
+		case b := <-s.blockCh:
+			resp.Event = response.BlockEventID
+			resp.Payload[0] = b
+		case execution := <-s.executionCh:
+			resp.Event = response.ExecutionEventID
+			resp.Payload[0] = result.NewApplicationLog(execution, util.Uint160{})
+		case notification := <-s.notificationCh:
+			resp.Event = response.NotificationEventID
+			resp.Payload[0] = result.StateEventToResultNotification(*notification)
+		case tx := <-s.transactionCh:
+			resp.Event = response.TransactionEventID
+			resp.Payload[0] = tx
+		}
+		s.subsLock.RLock()
+	subloop:
+		for sub := range s.subscribers {
+			for _, subID := range sub.feeds {
+				if subID == resp.Event {
+					if msg == nil {
+						b, err := json.Marshal(resp)
+						if err != nil {
+							s.log.Error("failed to marshal notification",
+								zap.Error(err),
+								zap.String("type", resp.Event.String()))
+							break subloop
+						}
+						msg, err = websocket.NewPreparedMessage(websocket.TextMessage, b)
+						if err != nil {
+							s.log.Error("failed to prepare notification message",
+								zap.Error(err),
+								zap.String("type", resp.Event.String()))
+							break subloop
+						}
+					}
+					sub.writer <- msg
+					// The message is sent only once per subscriber.
+					break
+				}
+			}
+		}
+		s.subsLock.RUnlock()
+	}
+	// It's important to do it with lock held because no subscription routine
+	// should be running concurrently to this one. And even if one is to run
+	// after unlock, it'll see closed s.shutdown and won't subscribe.
+	s.subsLock.Lock()
+	// There might be no subscription in reality, but it's not a problem as
+	// core.Blockchain allows unsubscribing non-subscribed channels.
+	s.chain.UnsubscribeFromBlocks(s.blockCh)
+	s.chain.UnsubscribeFromTransactions(s.transactionCh)
+	s.chain.UnsubscribeFromNotifications(s.notificationCh)
+	s.chain.UnsubscribeFromExecutions(s.executionCh)
+	s.subsLock.Unlock()
+drainloop:
+	for {
+		select {
+		case <-s.blockCh:
+		case <-s.executionCh:
+		case <-s.notificationCh:
+		case <-s.transactionCh:
+		default:
+			break drainloop
+		}
+	}
+	// It's not required closing these, but since they're drained already
+	// this is safe and it also allows to give a signal to Shutdown routine.
+	close(s.blockCh)
+	close(s.transactionCh)
+	close(s.notificationCh)
+	close(s.executionCh)
+}
+
 func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) {
 	num, err := param.GetInt()
 	if err != nil {
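Putting the server side together, here is a minimal client sketch (not part of the commit) that subscribes to new blocks and prints raw messages; the endpoint and message shapes follow the code and tests in this commit, while the address is an assumption to adjust to your node's configuration:

package main

import (
	"fmt"
	"time"

	"github.com/gorilla/websocket"
)

func main() {
	dialer := websocket.Dialer{HandshakeTimeout: time.Second}
	// Assumes a neo-go node with the RPC server listening on localhost:20332.
	ws, _, err := dialer.Dial("ws://localhost:20332/ws", nil)
	if err != nil {
		panic(err)
	}
	defer ws.Close()

	sub := `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`
	if err := ws.WriteMessage(websocket.TextMessage, []byte(sub)); err != nil {
		panic(err)
	}

	// Normally the first message is the subscribe response carrying the
	// subscription ID, everything after that is block_added notifications.
	for {
		_, body, err := ws.ReadMessage()
		if err != nil {
			return
		}
		fmt.Println(string(body))
	}
}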
@@ -15,12 +15,11 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/network"
 	"github.com/nspcc-dev/neo-go/pkg/util"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
 	"go.uber.org/zap/zaptest"
 )
 
-func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) {
-	var nBlocks uint32
-
+func getUnitTestChain(t *testing.T) (*core.Blockchain, config.Config, *zap.Logger) {
 	net := config.ModeUnitTestNet
 	configPath := "../../../config"
 	cfg, err := config.Load(configPath, net)
@@ -33,6 +32,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
 
 	go chain.Run()
 
+	return chain, cfg, logger
+}
+
+func getTestBlocks(t *testing.T) []*block.Block {
+	blocks := make([]*block.Block, 0)
 	// File "./testdata/testblocks.acc" was generated by function core._
 	// ("neo-go/pkg/core/helper_test.go").
 	// To generate new "./testdata/testblocks.acc", follow the steps:
@@ -42,15 +46,20 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
 	f, err := os.Open("testdata/testblocks.acc")
 	require.Nil(t, err)
 	br := io.NewBinReaderFromIO(f)
-	nBlocks = br.ReadU32LE()
+	nBlocks := br.ReadU32LE()
 	require.Nil(t, br.Err)
 	for i := 0; i < int(nBlocks); i++ {
 		_ = br.ReadU32LE()
 		b := &block.Block{}
 		b.DecodeBinary(br)
 		require.Nil(t, br.Err)
-		require.NoError(t, chain.AddBlock(b))
+		blocks = append(blocks, b)
 	}
+	return blocks
+}
+
+func initClearServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) {
+	chain, cfg, logger := getUnitTestChain(t)
+
 	serverConfig := network.NewServerConfig(cfg)
 	server, err := network.NewServer(serverConfig, chain, logger)
@@ -65,6 +74,15 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
 	return chain, &rpcServer, srv
 }
 
+func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) {
+	chain, rpcServer, srv := initClearServerWithInMemoryChain(t)
+
+	for _, b := range getTestBlocks(t) {
+		require.NoError(t, chain.AddBlock(b))
+	}
+	return chain, rpcServer, srv
+}
+
 type FeerStub struct{}
 
 func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {
pkg/rpc/server/subscription.go (new file, 35 lines)
@@ -0,0 +1,35 @@
package server

import (
	"github.com/gorilla/websocket"
	"github.com/nspcc-dev/neo-go/pkg/rpc/response"
)

type (
	// subscriber is an event subscriber.
	subscriber struct {
		writer chan<- *websocket.PreparedMessage
		ws     *websocket.Conn

		// These work like slots as there is not a lot of them (it's
		// cheaper doing it this way rather than creating a map),
		// pointing to EventID is an obvious overkill at the moment, but
		// that's not for long.
		feeds [maxFeeds]response.EventID
	}
)

const (
	// Maximum number of subscriptions per one client.
	maxFeeds = 16

	// This sets notification messages buffer depth, it may seem to be quite
	// big, but there is a big gap in speed between internal event processing
	// and networking communication that is combined with spiky nature of our
	// event generation process, which leads to lots of events generated in
	// short time and they will put some pressure to this buffer (consider
	// ~500 invocation txs in one block with some notifications). At the same
	// time this channel is about sending pointers, so it doesn't cost
	// a lot in terms of memory used.
	notificationBufSize = 1024
)
pkg/rpc/server/subscription_test.go (new file, 227 lines)
@@ -0,0 +1,227 @@
package server

import (
	"encoding/json"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/nspcc-dev/neo-go/pkg/core"
	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
	"github.com/nspcc-dev/neo-go/pkg/rpc/response"
	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"
)

func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool) {
	for {
		ws.SetReadDeadline(time.Now().Add(time.Second))
		_, body, err := ws.ReadMessage()
		if isFinished.Load() {
			require.Error(t, err)
			break
		}
		require.NoError(t, err)
		msgCh <- body
	}
}

func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *response.Raw {
	var resp = new(response.Raw)

	ws.SetWriteDeadline(time.Now().Add(time.Second))
	require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg)))

	body := <-respCh
	require.NoError(t, json.Unmarshal(body, resp))
	return resp
}

func getNotification(t *testing.T, respCh <-chan []byte) *response.Notification {
	var resp = new(response.Notification)
	body := <-respCh
	require.NoError(t, json.Unmarshal(body, resp))
	return resp
}

func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *websocket.Conn, chan []byte, *atomic.Bool) {
	chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)

	dialer := websocket.Dialer{HandshakeTimeout: time.Second}
	url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
	ws, _, err := dialer.Dial(url, nil)
	require.NoError(t, err)

	// Use buffered channel to read server's messages and then read expected
	// responses from it.
	respMsgs := make(chan []byte, 16)
	finishedFlag := atomic.NewBool(false)
	go wsReader(t, ws, respMsgs, finishedFlag)
	return chain, rpcSrv, ws, respMsgs, finishedFlag
}

func TestSubscriptions(t *testing.T) {
	var subIDs = make([]string, 0)
	var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"}

	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()

	for _, feed := range subFeeds {
		var s string
		resp := callWSGetRaw(t, c, fmt.Sprintf(`{
			"jsonrpc": "2.0",
			"method": "subscribe",
			"params": ["%s"],
			"id": 1
		}`, feed), respMsgs)
		require.Nil(t, resp.Error)
		require.NotNil(t, resp.Result)
		require.NoError(t, json.Unmarshal(resp.Result, &s))
		subIDs = append(subIDs, s)
	}

	for _, b := range getTestBlocks(t) {
		require.NoError(t, chain.AddBlock(b))
		for _, tx := range b.Transactions {
			var mayNotify bool

			if tx.Type == transaction.InvocationType {
				resp := getNotification(t, respMsgs)
				require.Equal(t, response.ExecutionEventID, resp.Event)
				mayNotify = true
			}
			for {
				resp := getNotification(t, respMsgs)
				if mayNotify && resp.Event == response.NotificationEventID {
					continue
				}
				require.Equal(t, response.TransactionEventID, resp.Event)
				break
			}
		}
		resp := getNotification(t, respMsgs)
		require.Equal(t, response.BlockEventID, resp.Event)
	}

	for _, id := range subIDs {
		var b bool

		resp := callWSGetRaw(t, c, fmt.Sprintf(`{
			"jsonrpc": "2.0",
			"method": "unsubscribe",
			"params": ["%s"],
			"id": 1
		}`, id), respMsgs)
		require.Nil(t, resp.Error)
		require.NotNil(t, resp.Result)
		require.NoError(t, json.Unmarshal(resp.Result, &b))
		require.Equal(t, true, b)
	}
	finishedFlag.CAS(false, true)
	c.Close()
}

func TestMaxSubscriptions(t *testing.T) {
	var subIDs = make([]string, 0)
	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()

	for i := 0; i < maxFeeds+1; i++ {
		var s string
		resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`, respMsgs)
		if i < maxFeeds {
			require.Nil(t, resp.Error)
			require.NotNil(t, resp.Result)
			require.NoError(t, json.Unmarshal(resp.Result, &s))
			// Each ID must be unique.
			for _, id := range subIDs {
				require.NotEqual(t, id, s)
			}
			subIDs = append(subIDs, s)
		} else {
			require.NotNil(t, resp.Error)
			require.Nil(t, resp.Result)
		}
	}

	finishedFlag.CAS(false, true)
	c.Close()
}

func TestBadSubUnsub(t *testing.T) {
	var subCases = map[string]string{
		"no params":              `{"jsonrpc": "2.0", "method": "subscribe", "params": [], "id": 1}`,
		"bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`,
		"bad (wrong) event":      `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`,
	}
	var unsubCases = map[string]string{
		"no params":         `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`,
		"bad id":            `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["vasiliy"], "id": 1}`,
		"not subscribed id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["7"], "id": 1}`,
	}
	chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)

	defer chain.Close()
	defer rpcSrv.Shutdown()

	testF := func(t *testing.T, cases map[string]string) func(t *testing.T) {
		return func(t *testing.T) {
			for n, s := range cases {
				t.Run(n, func(t *testing.T) {
					resp := callWSGetRaw(t, c, s, respMsgs)
					require.NotNil(t, resp.Error)
					require.Nil(t, resp.Result)
				})
			}
		}
	}
	t.Run("subscribe", testF(t, subCases))
	t.Run("unsubscribe", testF(t, unsubCases))

	finishedFlag.CAS(false, true)
	c.Close()
}

func doSomeWSRequest(t *testing.T, ws *websocket.Conn) {
	ws.SetWriteDeadline(time.Now().Add(time.Second))
	// It could be just about anything including invalid request,
	// we only care about server handling being active.
	require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`)))
	ws.SetReadDeadline(time.Now().Add(time.Second))
	_, _, err := ws.ReadMessage()
	require.NoError(t, err)
}

func TestWSClientsLimit(t *testing.T) {
	chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)
	defer chain.Close()
	defer rpcSrv.Shutdown()

	dialer := websocket.Dialer{HandshakeTimeout: time.Second}
	url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
	wss := make([]*websocket.Conn, maxSubscribers)

	for i := 0; i < len(wss)+1; i++ {
		ws, _, err := dialer.Dial(url, nil)
		if i < maxSubscribers {
			require.NoError(t, err)
			wss[i] = ws
			// Check that it's completely ready.
			doSomeWSRequest(t, ws)
		} else {
			require.Error(t, err)
		}
	}
	// Check connections are still alive (it actually is necessary to add
	// some use of wss to keep connections alive).
	for i := 0; i < len(wss); i++ {
		doSomeWSRequest(t, wss[i])
	}
}