rpc: allow batch JSON-RPC requests

Close #1509
This commit is contained in:
Anna Shaleva 2020-10-26 20:22:20 +03:00
parent 559024671a
commit ef3eb0a842
4 changed files with 263 additions and 51 deletions

View file

@ -1,7 +1,9 @@
package request package request
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -10,6 +12,9 @@ import (
const ( const (
// JSONRPCVersion is the only JSON-RPC protocol version supported. // JSONRPCVersion is the only JSON-RPC protocol version supported.
JSONRPCVersion = "2.0" JSONRPCVersion = "2.0"
// maxBatchSize is the maximum number of request per batch.
maxBatchSize = 100
) )
// RawParams is just a slice of abstract values, used to represent parameters // RawParams is just a slice of abstract values, used to represent parameters
@ -36,9 +41,16 @@ type Raw struct {
ID int `json:"id"` ID int `json:"id"`
} }
// Request contains standard JSON-RPC 2.0 request and batch of
// requests: http://www.jsonrpc.org/specification.
// It's used in server to represent incoming queries.
type Request struct {
In *In
Batch Batch
}
// In represents a standard JSON-RPC 2.0 // In represents a standard JSON-RPC 2.0
// request: http://www.jsonrpc.org/specification#request_object. It's used in // request: http://www.jsonrpc.org/specification#request_object.
// server to represent incoming queries.
type In struct { type In struct {
JSONRPC string `json:"jsonrpc"` JSONRPC string `json:"jsonrpc"`
Method string `json:"method"` Method string `json:"method"`
@ -46,28 +58,81 @@ type In struct {
RawID json.RawMessage `json:"id,omitempty"` RawID json.RawMessage `json:"id,omitempty"`
} }
// NewIn creates a new Request struct. // Batch represents a standard JSON-RPC 2.0
func NewIn() *In { // batch: https://www.jsonrpc.org/specification#batch.
return &In{ type Batch []In
JSONRPC: JSONRPCVersion,
// MarshalJSON implements json.Marshaler interface
func (r Request) MarshalJSON() ([]byte, error) {
if r.In != nil {
return json.Marshal(r.In)
} }
return json.Marshal(r.Batch)
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (r *Request) UnmarshalJSON(data []byte) error {
var (
in *In
batch Batch
)
in = &In{}
err := json.Unmarshal(data, in)
if err == nil {
r.In = in
return nil
}
decoder := json.NewDecoder(bytes.NewReader(data))
t, err := decoder.Token() // read `[`
if err != nil {
return err
}
if t != json.Delim('[') {
return fmt.Errorf("`[` expected, got %s", t)
}
count := 0
for decoder.More() {
if count > maxBatchSize {
return fmt.Errorf("the number of requests in batch shouldn't exceed %d", maxBatchSize)
}
in = &In{}
decodeErr := decoder.Decode(in)
if decodeErr != nil {
return decodeErr
}
batch = append(batch, *in)
count++
}
if len(batch) == 0 {
return errors.New("empty request")
}
r.Batch = batch
return nil
} }
// DecodeData decodes the given reader into the the request // DecodeData decodes the given reader into the the request
// struct. // struct.
func (r *In) DecodeData(data io.ReadCloser) error { func (r *Request) DecodeData(data io.ReadCloser) error {
defer data.Close() defer data.Close()
err := json.NewDecoder(data).Decode(r) rawData := json.RawMessage{}
err := json.NewDecoder(data).Decode(&rawData)
if err != nil { if err != nil {
return errors.Errorf("error parsing JSON payload: %s", err) return errors.Errorf("error parsing JSON payload: %s", err)
} }
return r.UnmarshalJSON(rawData)
}
if r.JSONRPC != JSONRPCVersion { // NewRequest creates a new Request struct.
return errors.Errorf("invalid version, expected 2.0 got: '%s'", r.JSONRPC) func NewRequest() *Request {
return &Request{}
}
// NewIn creates a new In struct.
func NewIn() *In {
return &In{
JSONRPC: JSONRPCVersion,
} }
return nil
} }
// Params takes a slice of any type and attempts to bind // Params takes a slice of any type and attempts to bind

View file

@ -38,6 +38,31 @@ type GetRawTx struct {
Result *result.TransactionOutputRaw `json:"result"` Result *result.TransactionOutputRaw `json:"result"`
} }
// AbstractResult is an interface which represents either single JSON-RPC 2.0 response
// or batch JSON-RPC 2.0 response.
type AbstractResult interface {
RunForErrors(f func(jsonErr *Error))
}
// RunForErrors implements AbstractResult interface.
func (r Raw) RunForErrors(f func(jsonErr *Error)) {
if r.Error != nil {
f(r.Error)
}
}
// RawBatch represents abstract JSON-RPC 2.0 batch-response.
type RawBatch []Raw
// RunForErrors implements AbstractResult interface.
func (rb RawBatch) RunForErrors(f func(jsonErr *Error)) {
for _, r := range rb {
if r.Error != nil {
f(r.Error)
}
}
}
// Notification is a type used to represent wire format of events, they're // Notification is a type used to represent wire format of events, they're
// special in that they look like requests but they don't have IDs and their // special in that they look like requests but they don't have IDs and their
// "method" is actually an event name. // "method" is actually an event name.

View file

@ -220,7 +220,7 @@ func (s *Server) Shutdown() error {
} }
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
req := request.NewIn() req := request.NewRequest()
if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
// Technically there is a race between this check and // Technically there is a race between this check and
@ -232,7 +232,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
s.subsLock.RUnlock() s.subsLock.RUnlock()
if numOfSubs >= maxSubscribers { if numOfSubs >= maxSubscribers {
s.writeHTTPErrorResponse( s.writeHTTPErrorResponse(
req, request.NewIn(),
w, w,
response.NewInternalServerError("websocket users limit reached", nil), response.NewInternalServerError("websocket users limit reached", nil),
) )
@ -243,7 +243,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
s.log.Info("websocket connection upgrade failed", zap.Error(err)) s.log.Info("websocket connection upgrade failed", zap.Error(err))
return return
} }
resChan := make(chan response.Raw) resChan := make(chan response.AbstractResult) // response.Raw or response.RawBatch
subChan := make(chan *websocket.PreparedMessage, notificationBufSize) subChan := make(chan *websocket.PreparedMessage, notificationBufSize)
subscr := &subscriber{writer: subChan, ws: ws} subscr := &subscriber{writer: subChan, ws: ws}
s.subsLock.Lock() s.subsLock.Lock()
@ -256,7 +256,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
if httpRequest.Method != "POST" { if httpRequest.Method != "POST" {
s.writeHTTPErrorResponse( s.writeHTTPErrorResponse(
req, request.NewIn(),
w, w,
response.NewInvalidParamsError( response.NewInvalidParamsError(
fmt.Sprintf("Invalid method '%s', please retry with 'POST'", httpRequest.Method), nil, fmt.Sprintf("Invalid method '%s', please retry with 'POST'", httpRequest.Method), nil,
@ -267,7 +267,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
err := req.DecodeData(httpRequest.Body) err := req.DecodeData(httpRequest.Body)
if err != nil { if err != nil {
s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) s.writeHTTPErrorResponse(request.NewIn(), w, response.NewParseError("Problem parsing JSON-RPC request body", err))
return return
} }
@ -275,9 +275,23 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ
s.writeHTTPServerResponse(req, w, resp) s.writeHTTPServerResponse(req, w, resp)
} }
func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw { func (s *Server) handleRequest(req *request.Request, sub *subscriber) response.AbstractResult {
if req.In != nil {
return s.handleIn(req.In, sub)
}
resp := make(response.RawBatch, len(req.Batch))
for i, in := range req.Batch {
resp[i] = s.handleIn(&in, sub)
}
return resp
}
func (s *Server) handleIn(req *request.In, sub *subscriber) response.Raw {
var res interface{} var res interface{}
var resErr *response.Error var resErr *response.Error
if req.JSONRPC != request.JSONRPCVersion {
return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing JSON", fmt.Errorf("invalid version, expected 2.0 got: '%s'", req.JSONRPC)))
}
reqParams, err := req.Params() reqParams, err := req.Params()
if err != nil { if err != nil {
@ -303,7 +317,7 @@ func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw {
return s.packResponseToRaw(req, res, resErr) return s.packResponseToRaw(req, res, resErr)
} }
func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) { func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.AbstractResult, subChan <-chan *websocket.PreparedMessage) {
pingTicker := time.NewTicker(wsPingPeriod) pingTicker := time.NewTicker(wsPingPeriod)
eventloop: eventloop:
for { for {
@ -350,21 +364,21 @@ drainloop:
} }
} }
func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) { func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.AbstractResult, subscr *subscriber) {
ws.SetReadLimit(wsReadLimit) ws.SetReadLimit(wsReadLimit)
ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetReadDeadline(time.Now().Add(wsPongLimit))
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
requestloop: requestloop:
for { for {
req := new(request.In) req := request.NewRequest()
err := ws.ReadJSON(req) err := ws.ReadJSON(req)
if err != nil { if err != nil {
break break
} }
res := s.handleRequest(req, subscr) res := s.handleRequest(req, subscr)
if res.Error != nil { res.RunForErrors(func(jsonErr *response.Error) {
s.logRequestError(req, res.Error) s.logRequestError(req, jsonErr)
} })
select { select {
case <-s.shutdown: case <-s.shutdown:
break requestloop break requestloop
@ -1842,15 +1856,17 @@ func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *r
} }
// logRequestError is a request error logger. // logRequestError is a request error logger.
func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { func (s *Server) logRequestError(r *request.Request, jsonErr *response.Error) {
logFields := []zap.Field{ logFields := []zap.Field{
zap.Error(jsonErr.Cause), zap.Error(jsonErr.Cause),
zap.String("method", r.Method),
} }
params, err := r.Params() if r.In != nil {
if err == nil { logFields = append(logFields, zap.String("method", r.In.Method))
logFields = append(logFields, zap.Any("params", params)) params, err := r.In.Params()
if err == nil {
logFields = append(logFields, zap.Any("params", params))
}
} }
s.log.Error("Error encountered with rpc request", logFields...) s.log.Error("Error encountered with rpc request", logFields...)
@ -1859,14 +1875,19 @@ func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) {
// writeHTTPErrorResponse writes an error response to the ResponseWriter. // writeHTTPErrorResponse writes an error response to the ResponseWriter.
func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) {
resp := s.packResponseToRaw(r, nil, jsonErr) resp := s.packResponseToRaw(r, nil, jsonErr)
s.writeHTTPServerResponse(r, w, resp) s.writeHTTPServerResponse(&request.Request{In: r}, w, resp)
} }
func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { func (s *Server) writeHTTPServerResponse(r *request.Request, w http.ResponseWriter, resp response.AbstractResult) {
// Errors can happen in many places and we can only catch ALL of them here. // Errors can happen in many places and we can only catch ALL of them here.
if resp.Error != nil { resp.RunForErrors(func(jsonErr *response.Error) {
s.logRequestError(r, resp.Error) s.logRequestError(r, jsonErr)
w.WriteHeader(resp.Error.HTTPCode) })
if r.In != nil {
resp := resp.(response.Raw)
if resp.Error != nil {
w.WriteHeader(resp.Error.HTTPCode)
}
} }
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
if s.config.EnableCORSWorkaround { if s.config.EnableCORSWorkaround {
@ -1878,9 +1899,15 @@ func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, r
err := encoder.Encode(resp) err := encoder.Encode(resp)
if err != nil { if err != nil {
s.log.Error("Error encountered while encoding response", switch {
zap.String("err", err.Error()), case r.In != nil:
zap.String("method", r.Method)) s.log.Error("Error encountered while encoding response",
zap.String("err", err.Error()),
zap.String("method", r.In.Method))
case r.Batch != nil:
s.log.Error("Error encountered while encoding batch response",
zap.String("err", err.Error()))
}
} }
} }

View file

@ -968,31 +968,111 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []
defer rpcSrv.Shutdown() defer rpcSrv.Shutdown()
e := &executor{chain: chain, httpSrv: httpSrv} e := &executor{chain: chain, httpSrv: httpSrv}
for method, cases := range rpcTestCases { t.Run("single request", func(t *testing.T) {
t.Run(method, func(t *testing.T) { for method, cases := range rpcTestCases {
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` t.Run(method, func(t *testing.T) {
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}`
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t)
result := checkErrGetResult(t, body, tc.fail) result := checkErrGetResult(t, body, tc.fail)
if tc.fail {
return
}
expected, res := tc.getResultPair(e)
err := json.Unmarshal(result, res)
require.NoErrorf(t, err, "could not parse response: %s", result)
if tc.check == nil {
assert.Equal(t, expected, res)
} else {
tc.check(t, e, res)
}
})
}
})
}
})
t.Run("batch with single request", func(t *testing.T) {
for method, cases := range rpcTestCases {
if method == "sendrawtransaction" || method == "submitblock" {
continue // cannot send the same transaction twice
}
t.Run(method, func(t *testing.T) {
rpc := `[{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}]`
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t)
result := checkErrGetBatchResult(t, body, tc.fail)
if tc.fail {
return
}
expected, res := tc.getResultPair(e)
err := json.Unmarshal(result, res)
require.NoErrorf(t, err, "could not parse response: %s", result)
if tc.check == nil {
assert.Equal(t, expected, res)
} else {
tc.check(t, e, res)
}
})
}
})
}
})
t.Run("batch with multiple requests", func(t *testing.T) {
for method, cases := range rpcTestCases {
if method == "sendrawtransaction" || method == "submitblock" {
continue // cannot send the same transaction twice
}
t.Run(method, func(t *testing.T) {
rpc := `{"jsonrpc": "2.0", "id": %d, "method": "%s", "params": %s},`
var resultRPC string
for i, tc := range cases {
resultRPC += fmt.Sprintf(rpc, i, method, tc.params)
}
resultRPC = `[` + resultRPC[:len(resultRPC)-1] + `]`
body := doRPCCall(resultRPC, httpSrv.URL, t)
var responses []response.Raw
err := json.Unmarshal(body, &responses)
require.Nil(t, err)
for i, tc := range cases {
var resp response.Raw
for _, r := range responses {
if bytes.Equal(r.ID, []byte(strconv.Itoa(i))) {
resp = r
break
}
}
if tc.fail {
require.NotNil(t, resp.Error)
assert.NotEqual(t, 0, resp.Error.Code)
assert.NotEqual(t, "", resp.Error.Message)
} else {
assert.Nil(t, resp.Error)
}
if tc.fail { if tc.fail {
return return
} }
expected, res := tc.getResultPair(e) expected, res := tc.getResultPair(e)
err := json.Unmarshal(result, res) err := json.Unmarshal(resp.Result, res)
require.NoErrorf(t, err, "could not parse response: %s", result) require.NoErrorf(t, err, "could not parse response: %s", resp.Result)
if tc.check == nil { if tc.check == nil {
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
} else { } else {
tc.check(t, e, res) tc.check(t, e, res)
} }
}) }
} })
}) }
} })
t.Run("getproof", func(t *testing.T) { t.Run("getproof", func(t *testing.T) {
r, err := chain.GetStateRoot(210) r, err := chain.GetStateRoot(210)
@ -1429,6 +1509,21 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe
return resp.Result return resp.Result
} }
func checkErrGetBatchResult(t *testing.T, body []byte, expectingFail bool) json.RawMessage {
var resp []response.Raw
err := json.Unmarshal(body, &resp)
require.Nil(t, err)
require.Equal(t, 1, len(resp))
if expectingFail {
require.NotNil(t, resp[0].Error)
assert.NotEqual(t, 0, resp[0].Error.Code)
assert.NotEqual(t, "", resp[0].Error.Message)
} else {
assert.Nil(t, resp[0].Error)
}
return resp[0].Result
}
func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte { func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte {
dialer := websocket.Dialer{HandshakeTimeout: time.Second} dialer := websocket.Dialer{HandshakeTimeout: time.Second}
url = "ws" + strings.TrimPrefix(url, "http") url = "ws" + strings.TrimPrefix(url, "http")