rpc/server: add notification subscription

Note that the protocol differs a bit from #895 in its notifications format,
to avoid additional server-side processing we're omitting some metadata like:
 * block size and confirmations
 * transaction fees, confirmations, block hash and timestamp
 * application execution doesn't have ScriptHash populated

Some block fields may also differ in encoding compared to `getblock` results
(like nonce field).

I think these differences are unnoticieable for most use cases, so we can
leave them as is, but it can be changed in the future.
This commit is contained in:
Roman Khimov 2020-05-11 01:00:19 +03:00
parent 29ada4ca46
commit e1408b6525
7 changed files with 688 additions and 24 deletions

View file

@ -0,0 +1,79 @@
package response
import (
"encoding/json"
"github.com/pkg/errors"
)
type (
// EventID represents an event type happening on the chain.
EventID byte
)
const (
// InvalidEventID is an invalid event id that is the default value of
// EventID. It's only used as an initial value similar to nil.
InvalidEventID EventID = iota
// BlockEventID is a `block_added` event.
BlockEventID
// TransactionEventID corresponds to `transaction_added` event.
TransactionEventID
// NotificationEventID represents `notification_from_execution` events.
NotificationEventID
// ExecutionEventID is used for `transaction_executed` events.
ExecutionEventID
)
// String is a good old Stringer implementation.
func (e EventID) String() string {
switch e {
case BlockEventID:
return "block_added"
case TransactionEventID:
return "transaction_added"
case NotificationEventID:
return "notification_from_execution"
case ExecutionEventID:
return "transaction_executed"
default:
return "unknown"
}
}
// GetEventIDFromString converts input string into an EventID if it's possible.
func GetEventIDFromString(s string) (EventID, error) {
switch s {
case "block_added":
return BlockEventID, nil
case "transaction_added":
return TransactionEventID, nil
case "notification_from_execution":
return NotificationEventID, nil
case "transaction_executed":
return ExecutionEventID, nil
default:
return 255, errors.New("invalid stream name")
}
}
// MarshalJSON implements json.Marshaler interface.
func (e EventID) MarshalJSON() ([]byte, error) {
return json.Marshal(e.String())
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (e *EventID) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
id, err := GetEventIDFromString(s)
if err != nil {
return err
}
*e = id
return nil
}

View file

@ -30,16 +30,22 @@ type NotificationEvent struct {
Item smartcontract.Parameter `json:"state"` Item smartcontract.Parameter `json:"state"`
} }
// StateEventToResultNotification converts state.NotificationEvent to
// result.NotificationEvent.
func StateEventToResultNotification(event state.NotificationEvent) NotificationEvent {
seen := make(map[vm.StackItem]bool)
item := event.Item.ToContractParameter(seen)
return NotificationEvent{
Contract: event.ScriptHash,
Item: item,
}
}
// NewApplicationLog creates a new ApplicationLog wrapper. // NewApplicationLog creates a new ApplicationLog wrapper.
func NewApplicationLog(appExecRes *state.AppExecResult, scriptHash util.Uint160) ApplicationLog { func NewApplicationLog(appExecRes *state.AppExecResult, scriptHash util.Uint160) ApplicationLog {
events := make([]NotificationEvent, 0, len(appExecRes.Events)) events := make([]NotificationEvent, 0, len(appExecRes.Events))
for _, e := range appExecRes.Events { for _, e := range appExecRes.Events {
seen := make(map[vm.StackItem]bool) events = append(events, StateEventToResultNotification(e))
item := e.Item.ToContractParameter(seen)
events = append(events, NotificationEvent{
Contract: e.ScriptHash,
Item: item,
})
} }
triggerString := appExecRes.Trigger.String() triggerString := appExecRes.Trigger.String()

View file

@ -37,3 +37,12 @@ type GetRawTx struct {
HeaderAndError HeaderAndError
Result *result.TransactionOutputRaw `json:"result"` Result *result.TransactionOutputRaw `json:"result"`
} }
// Notification is a type used to represent wire format of events, they're
// special in that they look like requests but they don't have IDs and their
// "method" is actually an event name.
type Notification struct {
JSONRPC string `json:"jsonrpc"`
Event EventID `json:"method"`
Payload []interface{} `json:"params"`
}

View file

@ -9,6 +9,7 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -41,6 +42,19 @@ type (
coreServer *network.Server coreServer *network.Server
log *zap.Logger log *zap.Logger
https *http.Server https *http.Server
shutdown chan struct{}
subsLock sync.RWMutex
subscribers map[*subscriber]bool
subsGroup sync.WaitGroup
blockSubs int
executionSubs int
notificationSubs int
transactionSubs int
blockCh chan *block.Block
executionCh chan *state.AppExecResult
notificationCh chan *state.NotificationEvent
transactionCh chan *transaction.Transaction
} }
) )
@ -56,6 +70,11 @@ const (
// Write deadline. // Write deadline.
wsWriteLimit = wsPingPeriod / 2 wsWriteLimit = wsPingPeriod / 2
// Maximum number of subscribers per Server. Each websocket client is
// treated like subscriber, so technically it's a limit on websocket
// connections.
maxSubscribers = 64
) )
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){
@ -91,6 +110,11 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *respon
"validateaddress": (*Server).validateAddress, "validateaddress": (*Server).validateAddress,
} }
var rpcWsHandlers = map[string]func(*Server, request.Params, *subscriber) (interface{}, *response.Error){
"subscribe": (*Server).subscribe,
"unsubscribe": (*Server).unsubscribe,
}
var invalidBlockHeightError = func(index int, height int) *response.Error { var invalidBlockHeightError = func(index int, height int) *response.Error {
return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil)
} }
@ -119,6 +143,14 @@ func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, l
coreServer: coreServer, coreServer: coreServer,
log: log, log: log,
https: tlsServer, https: tlsServer,
shutdown: make(chan struct{}),
subscribers: make(map[*subscriber]bool),
// These are NOT buffered to preserve original order of events.
blockCh: make(chan *block.Block),
executionCh: make(chan *state.AppExecResult),
notificationCh: make(chan *state.NotificationEvent),
transactionCh: make(chan *transaction.Transaction),
} }
} }
@ -133,6 +165,7 @@ func (s *Server) Start(errChan chan error) {
s.Handler = http.HandlerFunc(s.handleHTTPRequest) s.Handler = http.HandlerFunc(s.handleHTTPRequest)
s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr)) s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr))
go s.handleSubEvents()
if cfg := s.config.TLSConfig; cfg.Enabled { if cfg := s.config.TLSConfig; cfg.Enabled {
s.https.Handler = http.HandlerFunc(s.handleHTTPRequest) s.https.Handler = http.HandlerFunc(s.handleHTTPRequest)
s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr)) s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr))
@ -155,6 +188,10 @@ func (s *Server) Start(errChan chan error) {
// method. // method.
func (s *Server) Shutdown() error { func (s *Server) Shutdown() error {
var httpsErr error var httpsErr error
// Signal to websocket writer routines and handleSubEvents.
close(s.shutdown)
if s.config.TLSConfig.Enabled { if s.config.TLSConfig.Enabled {
s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr)) s.log.Info("shutting down rpc-server (https)", zap.String("endpoint", s.https.Addr))
httpsErr = s.https.Shutdown(context.Background()) httpsErr = s.https.Shutdown(context.Background())
@ -162,6 +199,10 @@ func (s *Server) Shutdown() error {
s.log.Info("shutting down rpc-server", zap.String("endpoint", s.Addr)) s.log.Info("shutting down rpc-server", zap.String("endpoint", s.Addr))
err := s.Server.Shutdown(context.Background()) err := s.Server.Shutdown(context.Background())
// Wait for handleSubEvents to finish.
<-s.executionCh
if err == nil { if err == nil {
return httpsErr return httpsErr
} }
@ -169,20 +210,40 @@ func (s *Server) Shutdown() error {
} }
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
req := request.NewIn()
if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
// Technically there is a race between this check and
// s.subscribers modification 20 lines below, but it's tiny
// and not really critical to bother with it. Some additional
// clients may sneak in, no big deal.
s.subsLock.RLock()
numOfSubs := len(s.subscribers)
s.subsLock.RUnlock()
if numOfSubs >= maxSubscribers {
s.writeHTTPErrorResponse(
req,
w,
response.NewInternalServerError("websocket users limit reached", nil),
)
return
}
ws, err := upgrader.Upgrade(w, httpRequest, nil) ws, err := upgrader.Upgrade(w, httpRequest, nil)
if err != nil { if err != nil {
s.log.Info("websocket connection upgrade failed", zap.Error(err)) s.log.Info("websocket connection upgrade failed", zap.Error(err))
return return
} }
resChan := make(chan response.Raw) resChan := make(chan response.Raw)
go s.handleWsWrites(ws, resChan) subChan := make(chan *websocket.PreparedMessage, notificationBufSize)
s.handleWsReads(ws, resChan) subscr := &subscriber{writer: subChan, ws: ws}
s.subsLock.Lock()
s.subscribers[subscr] = true
s.subsLock.Unlock()
go s.handleWsWrites(ws, resChan, subChan)
s.handleWsReads(ws, resChan, subscr)
return return
} }
req := request.NewIn()
if httpRequest.Method != "POST" { if httpRequest.Method != "POST" {
s.writeHTTPErrorResponse( s.writeHTTPErrorResponse(
req, req,
@ -200,11 +261,14 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
return return
} }
resp := s.handleRequest(req) resp := s.handleRequest(req, nil)
s.writeHTTPServerResponse(req, w, resp) s.writeHTTPServerResponse(req, w, resp)
} }
func (s *Server) handleRequest(req *request.In) response.Raw { func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw {
var res interface{}
var resErr *response.Error
reqParams, err := req.Params() reqParams, err := req.Params()
if err != nil { if err != nil {
return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err)) return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err))
@ -216,20 +280,37 @@ func (s *Server) handleRequest(req *request.In) response.Raw {
incCounter(req.Method) incCounter(req.Method)
resErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)
handler, ok := rpcHandlers[req.Method] handler, ok := rpcHandlers[req.Method]
if !ok { if ok {
return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)) res, resErr = handler(s, *reqParams)
} else if sub != nil {
handler, ok := rpcWsHandlers[req.Method]
if ok {
res, resErr = handler(s, *reqParams, sub)
}
} }
res, resErr := handler(s, *reqParams)
return s.packResponseToRaw(req, res, resErr) return s.packResponseToRaw(req, res, resErr)
} }
func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) { func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) {
pingTicker := time.NewTicker(wsPingPeriod) pingTicker := time.NewTicker(wsPingPeriod)
defer ws.Close() defer ws.Close()
defer pingTicker.Stop() defer pingTicker.Stop()
for { for {
select { select {
case <-s.shutdown:
// Signal to the reader routine.
ws.Close()
return
case event, ok := <-subChan:
if !ok {
return
}
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
if err := ws.WritePreparedMessage(event); err != nil {
return
}
case res, ok := <-resChan: case res, ok := <-resChan:
if !ok { if !ok {
return return
@ -247,22 +328,36 @@ func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw)
} }
} }
func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) { func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) {
ws.SetReadLimit(wsReadLimit) ws.SetReadLimit(wsReadLimit)
ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetReadDeadline(time.Now().Add(wsPongLimit))
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
requestloop:
for { for {
req := new(request.In) req := new(request.In)
err := ws.ReadJSON(req) err := ws.ReadJSON(req)
if err != nil { if err != nil {
break break
} }
res := s.handleRequest(req) res := s.handleRequest(req, subscr)
if res.Error != nil { if res.Error != nil {
s.logRequestError(req, res.Error) s.logRequestError(req, res.Error)
} }
resChan <- res select {
case <-s.shutdown:
break requestloop
case resChan <- res:
}
} }
s.subsLock.Lock()
delete(s.subscribers, subscr)
for _, e := range subscr.feeds {
if e != response.InvalidEventID {
s.unsubscribeFromChannel(e)
}
}
s.subsLock.Unlock()
close(resChan) close(resChan)
ws.Close() ws.Close()
} }
@ -1025,6 +1120,201 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
return results, resultsErr return results, resultsErr
} }
// subscribe handles subscription requests from websocket clients.
func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
streamName, err := p.GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
event, err := response.GetEventIDFromString(streamName)
if err != nil {
return nil, response.ErrInvalidParams
}
s.subsLock.Lock()
defer s.subsLock.Unlock()
select {
case <-s.shutdown:
return nil, response.NewInternalServerError("server is shutting down", nil)
default:
}
var id int
for ; id < len(sub.feeds); id++ {
if sub.feeds[id] == response.InvalidEventID {
break
}
}
if id == len(sub.feeds) {
return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil)
}
sub.feeds[id] = event
s.subscribeToChannel(event)
return strconv.FormatInt(int64(id), 10), nil
}
// subscribeToChannel subscribes RPC server to appropriate chain events if
// it's not yet subscribed for them. It's supposed to be called with s.subsLock
// taken by the caller.
func (s *Server) subscribeToChannel(event response.EventID) {
switch event {
case response.BlockEventID:
if s.blockSubs == 0 {
s.chain.SubscribeForBlocks(s.blockCh)
}
s.blockSubs++
case response.TransactionEventID:
if s.transactionSubs == 0 {
s.chain.SubscribeForTransactions(s.transactionCh)
}
s.transactionSubs++
case response.NotificationEventID:
if s.notificationSubs == 0 {
s.chain.SubscribeForNotifications(s.notificationCh)
}
s.notificationSubs++
case response.ExecutionEventID:
if s.executionSubs == 0 {
s.chain.SubscribeForExecutions(s.executionCh)
}
s.executionSubs++
}
}
// unsubscribe handles unsubscription requests from websocket clients.
func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
id, err := p.GetInt()
if err != nil || id < 0 {
return nil, response.ErrInvalidParams
}
s.subsLock.Lock()
defer s.subsLock.Unlock()
if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID {
return nil, response.ErrInvalidParams
}
event := sub.feeds[id]
sub.feeds[id] = response.InvalidEventID
s.unsubscribeFromChannel(event)
return true, nil
}
// unsubscribeFromChannel unsubscribes RPC server from appropriate chain events
// if there are no other subscribers for it. It's supposed to be called with
// s.subsLock taken by the caller.
func (s *Server) unsubscribeFromChannel(event response.EventID) {
switch event {
case response.BlockEventID:
s.blockSubs--
if s.blockSubs == 0 {
s.chain.UnsubscribeFromBlocks(s.blockCh)
}
case response.TransactionEventID:
s.transactionSubs--
if s.transactionSubs == 0 {
s.chain.UnsubscribeFromTransactions(s.transactionCh)
}
case response.NotificationEventID:
s.notificationSubs--
if s.notificationSubs == 0 {
s.chain.UnsubscribeFromNotifications(s.notificationCh)
}
case response.ExecutionEventID:
s.executionSubs--
if s.executionSubs == 0 {
s.chain.UnsubscribeFromExecutions(s.executionCh)
}
}
}
func (s *Server) handleSubEvents() {
chloop:
for {
var resp = response.Notification{
JSONRPC: request.JSONRPCVersion,
Payload: make([]interface{}, 1),
}
var msg *websocket.PreparedMessage
select {
case <-s.shutdown:
break chloop
case b := <-s.blockCh:
resp.Event = response.BlockEventID
resp.Payload[0] = b
case execution := <-s.executionCh:
resp.Event = response.ExecutionEventID
resp.Payload[0] = result.NewApplicationLog(execution, util.Uint160{})
case notification := <-s.notificationCh:
resp.Event = response.NotificationEventID
resp.Payload[0] = result.StateEventToResultNotification(*notification)
case tx := <-s.transactionCh:
resp.Event = response.TransactionEventID
resp.Payload[0] = tx
}
s.subsLock.RLock()
subloop:
for sub := range s.subscribers {
for _, subID := range sub.feeds {
if subID == resp.Event {
if msg == nil {
b, err := json.Marshal(resp)
if err != nil {
s.log.Error("failed to marshal notification",
zap.Error(err),
zap.String("type", resp.Event.String()))
break subloop
}
msg, err = websocket.NewPreparedMessage(websocket.TextMessage, b)
if err != nil {
s.log.Error("failed to prepare notification message",
zap.Error(err),
zap.String("type", resp.Event.String()))
break subloop
}
}
sub.writer <- msg
// The message is sent only once per subscriber.
break
}
}
}
s.subsLock.RUnlock()
}
// It's important to do it with lock held because no subscription routine
// should be running concurrently to this one. And even if one is to run
// after unlock, it'll see closed s.shutdown and won't subscribe.
s.subsLock.Lock()
// There might be no subscription in reality, but it's not a problem as
// core.Blockchain allows unsubscribing non-subscribed channels.
s.chain.UnsubscribeFromBlocks(s.blockCh)
s.chain.UnsubscribeFromTransactions(s.transactionCh)
s.chain.UnsubscribeFromNotifications(s.notificationCh)
s.chain.UnsubscribeFromExecutions(s.executionCh)
s.subsLock.Unlock()
drainloop:
for {
select {
case <-s.blockCh:
case <-s.executionCh:
case <-s.notificationCh:
case <-s.transactionCh:
default:
break drainloop
}
}
// It's not required closing these, but since they're drained already
// this is safe and it also allows to give a signal to Shutdown routine.
close(s.blockCh)
close(s.transactionCh)
close(s.notificationCh)
close(s.executionCh)
}
func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) {
num, err := param.GetInt() num, err := param.GetInt()
if err != nil { if err != nil {

View file

@ -15,12 +15,11 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/network"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
) )
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) { func getUnitTestChain(t *testing.T) (*core.Blockchain, config.Config, *zap.Logger) {
var nBlocks uint32
net := config.ModeUnitTestNet net := config.ModeUnitTestNet
configPath := "../../../config" configPath := "../../../config"
cfg, err := config.Load(configPath, net) cfg, err := config.Load(configPath, net)
@ -33,6 +32,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
go chain.Run() go chain.Run()
return chain, cfg, logger
}
func getTestBlocks(t *testing.T) []*block.Block {
blocks := make([]*block.Block, 0)
// File "./testdata/testblocks.acc" was generated by function core._ // File "./testdata/testblocks.acc" was generated by function core._
// ("neo-go/pkg/core/helper_test.go"). // ("neo-go/pkg/core/helper_test.go").
// To generate new "./testdata/testblocks.acc", follow the steps: // To generate new "./testdata/testblocks.acc", follow the steps:
@ -42,15 +46,20 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
f, err := os.Open("testdata/testblocks.acc") f, err := os.Open("testdata/testblocks.acc")
require.Nil(t, err) require.Nil(t, err)
br := io.NewBinReaderFromIO(f) br := io.NewBinReaderFromIO(f)
nBlocks = br.ReadU32LE() nBlocks := br.ReadU32LE()
require.Nil(t, br.Err) require.Nil(t, br.Err)
for i := 0; i < int(nBlocks); i++ { for i := 0; i < int(nBlocks); i++ {
_ = br.ReadU32LE() _ = br.ReadU32LE()
b := &block.Block{} b := &block.Block{}
b.DecodeBinary(br) b.DecodeBinary(br)
require.Nil(t, br.Err) require.Nil(t, br.Err)
require.NoError(t, chain.AddBlock(b)) blocks = append(blocks, b)
} }
return blocks
}
func initClearServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) {
chain, cfg, logger := getUnitTestChain(t)
serverConfig := network.NewServerConfig(cfg) serverConfig := network.NewServerConfig(cfg)
server, err := network.NewServer(serverConfig, chain, logger) server, err := network.NewServer(serverConfig, chain, logger)
@ -65,6 +74,15 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http
return chain, &rpcServer, srv return chain, &rpcServer, srv
} }
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *httptest.Server) {
chain, rpcServer, srv := initClearServerWithInMemoryChain(t)
for _, b := range getTestBlocks(t) {
require.NoError(t, chain.AddBlock(b))
}
return chain, rpcServer, srv
}
type FeerStub struct{} type FeerStub struct{}
func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {

View file

@ -0,0 +1,35 @@
package server
import (
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
)
type (
// subscriber is an event subscriber.
subscriber struct {
writer chan<- *websocket.PreparedMessage
ws *websocket.Conn
// These work like slots as there is not a lot of them (it's
// cheaper doing it this way rather than creating a map),
// pointing to EventID is an obvious overkill at the moment, but
// that's not for long.
feeds [maxFeeds]response.EventID
}
)
const (
// Maximum number of subscriptions per one client.
maxFeeds = 16
// This sets notification messages buffer depth, it may seem to be quite
// big, but there is a big gap in speed between internal event processing
// and networking communication that is combined with spiky nature of our
// event generation process, which leads to lots of events generated in
// short time and they will put some pressure to this buffer (consider
// ~500 invocation txs in one block with some notifications). At the same
// time this channel is about sending pointers, so it's doesn't cost
// a lot in terms of memory used.
notificationBufSize = 1024
)

View file

@ -0,0 +1,227 @@
package server
import (
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)
func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool) {
for {
ws.SetReadDeadline(time.Now().Add(time.Second))
_, body, err := ws.ReadMessage()
if isFinished.Load() {
require.Error(t, err)
break
}
require.NoError(t, err)
msgCh <- body
}
}
func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *response.Raw {
var resp = new(response.Raw)
ws.SetWriteDeadline(time.Now().Add(time.Second))
require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg)))
body := <-respCh
require.NoError(t, json.Unmarshal(body, resp))
return resp
}
func getNotification(t *testing.T, respCh <-chan []byte) *response.Notification {
var resp = new(response.Notification)
body := <-respCh
require.NoError(t, json.Unmarshal(body, resp))
return resp
}
func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *websocket.Conn, chan []byte, *atomic.Bool) {
chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)
dialer := websocket.Dialer{HandshakeTimeout: time.Second}
url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
ws, _, err := dialer.Dial(url, nil)
require.NoError(t, err)
// Use buffered channel to read server's messages and then read expected
// responses from it.
respMsgs := make(chan []byte, 16)
finishedFlag := atomic.NewBool(false)
go wsReader(t, ws, respMsgs, finishedFlag)
return chain, rpcSrv, ws, respMsgs, finishedFlag
}
func TestSubscriptions(t *testing.T) {
var subIDs = make([]string, 0)
var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"}
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
defer chain.Close()
defer rpcSrv.Shutdown()
for _, feed := range subFeeds {
var s string
resp := callWSGetRaw(t, c, fmt.Sprintf(`{
"jsonrpc": "2.0",
"method": "subscribe",
"params": ["%s"],
"id": 1
}`, feed), respMsgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &s))
subIDs = append(subIDs, s)
}
for _, b := range getTestBlocks(t) {
require.NoError(t, chain.AddBlock(b))
for _, tx := range b.Transactions {
var mayNotify bool
if tx.Type == transaction.InvocationType {
resp := getNotification(t, respMsgs)
require.Equal(t, response.ExecutionEventID, resp.Event)
mayNotify = true
}
for {
resp := getNotification(t, respMsgs)
if mayNotify && resp.Event == response.NotificationEventID {
continue
}
require.Equal(t, response.TransactionEventID, resp.Event)
break
}
}
resp := getNotification(t, respMsgs)
require.Equal(t, response.BlockEventID, resp.Event)
}
for _, id := range subIDs {
var b bool
resp := callWSGetRaw(t, c, fmt.Sprintf(`{
"jsonrpc": "2.0",
"method": "unsubscribe",
"params": ["%s"],
"id": 1
}`, id), respMsgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &b))
require.Equal(t, true, b)
}
finishedFlag.CAS(false, true)
c.Close()
}
func TestMaxSubscriptions(t *testing.T) {
var subIDs = make([]string, 0)
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
defer chain.Close()
defer rpcSrv.Shutdown()
for i := 0; i < maxFeeds+1; i++ {
var s string
resp := callWSGetRaw(t, c, `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added"], "id": 1}`, respMsgs)
if i < maxFeeds {
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &s))
// Each ID must be unique.
for _, id := range subIDs {
require.NotEqual(t, id, s)
}
subIDs = append(subIDs, s)
} else {
require.NotNil(t, resp.Error)
require.Nil(t, resp.Result)
}
}
finishedFlag.CAS(false, true)
c.Close()
}
func TestBadSubUnsub(t *testing.T) {
var subCases = map[string]string{
"no params": `{"jsonrpc": "2.0", "method": "subscribe", "params": [], "id": 1}`,
"bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`,
"bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`,
}
var unsubCases = map[string]string{
"no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`,
"bad id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["vasiliy"], "id": 1}`,
"not subscribed id": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": ["7"], "id": 1}`,
}
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
defer chain.Close()
defer rpcSrv.Shutdown()
testF := func(t *testing.T, cases map[string]string) func(t *testing.T) {
return func(t *testing.T) {
for n, s := range cases {
t.Run(n, func(t *testing.T) {
resp := callWSGetRaw(t, c, s, respMsgs)
require.NotNil(t, resp.Error)
require.Nil(t, resp.Result)
})
}
}
}
t.Run("subscribe", testF(t, subCases))
t.Run("unsubscribe", testF(t, unsubCases))
finishedFlag.CAS(false, true)
c.Close()
}
func doSomeWSRequest(t *testing.T, ws *websocket.Conn) {
ws.SetWriteDeadline(time.Now().Add(time.Second))
// It could be just about anything including invalid request,
// we only care about server handling being active.
require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`)))
ws.SetReadDeadline(time.Now().Add(time.Second))
_, _, err := ws.ReadMessage()
require.NoError(t, err)
}
func TestWSClientsLimit(t *testing.T) {
chain, rpcSrv, httpSrv := initClearServerWithInMemoryChain(t)
defer chain.Close()
defer rpcSrv.Shutdown()
dialer := websocket.Dialer{HandshakeTimeout: time.Second}
url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
wss := make([]*websocket.Conn, maxSubscribers)
for i := 0; i < len(wss)+1; i++ {
ws, _, err := dialer.Dial(url, nil)
if i < maxSubscribers {
require.NoError(t, err)
wss[i] = ws
// Check that it's completely ready.
doSomeWSRequest(t, ws)
} else {
require.Error(t, err)
}
}
// Check connections are still alive (it actually is necessary to add
// some use of wss to keep connections alive).
for i := 0; i < len(wss); i++ {
doSomeWSRequest(t, wss[i])
}
}