rpc/server: add websockets support via '/ws' URL

This commit is contained in:
Roman Khimov 2020-04-29 15:25:58 +03:00
parent 8cec6694ae
commit ec62edac68
4 changed files with 108 additions and 4 deletions

View file

@ -9,12 +9,12 @@ import (
"net"
"net/http"
"strconv"
"time"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/rpc"
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -22,6 +22,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network"
"github.com/nspcc-dev/neo-go/pkg/rpc"
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
@ -44,6 +45,20 @@ type (
}
)
const (
// Message limit for receiving side.
wsReadLimit = 4096
// Disconnection timeout.
wsPongLimit = 60 * time.Second
// Ping period for connection liveness check.
wsPingPeriod = wsPongLimit / 2
// Write deadline.
wsWriteLimit = wsPingPeriod / 2
)
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){
"getaccountstate": (*Server).getAccountState,
"getapplicationlog": (*Server).getApplicationLog,
@ -81,6 +96,10 @@ var invalidBlockHeightError = func(index int, height int) *response.Error {
return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil)
}
// upgrader is a no-op websocket.Upgrader that reuses HTTP server buffers and
// doesn't set any Error function.
var upgrader = websocket.Upgrader{}
// New creates a new Server struct.
func New(chain blockchainer.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server {
httpServer := &http.Server{
@ -150,6 +169,18 @@ func (s *Server) Shutdown() error {
}
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
ws, err := upgrader.Upgrade(w, httpRequest, nil)
if err != nil {
s.log.Info("websocket connection upgrade failed", zap.Error(err))
return
}
resChan := make(chan response.Raw)
go s.handleWsWrites(ws, resChan)
s.handleWsReads(ws, resChan)
return
}
req := request.NewIn()
if httpRequest.Method != "POST" {
@ -193,6 +224,49 @@ func (s *Server) handleRequest(req *request.In) response.Raw {
return s.packResponseToRaw(req, res, resErr)
}
func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) {
pingTicker := time.NewTicker(wsPingPeriod)
defer ws.Close()
defer pingTicker.Stop()
for {
select {
case res, ok := <-resChan:
if !ok {
return
}
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
if err := ws.WriteJSON(res); err != nil {
return
}
case <-pingTicker.C:
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) {
ws.SetReadLimit(wsReadLimit)
ws.SetReadDeadline(time.Now().Add(wsPongLimit))
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
for {
req := new(request.In)
err := ws.ReadJSON(req)
if err != nil {
break
}
res := s.handleRequest(req)
if res.Error != nil {
s.logRequestError(req, res.Error)
}
resChan <- res
}
close(resChan)
ws.Close()
}
func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) {
return "0x" + s.chain.CurrentBlockHash().StringLE(), nil
}