diff --git a/pkg/rpc/request/types.go b/pkg/rpc/request/types.go index 9020d508e..0af1074b1 100644 --- a/pkg/rpc/request/types.go +++ b/pkg/rpc/request/types.go @@ -1,7 +1,9 @@ package request import ( + "bytes" "encoding/json" + "fmt" "io" "github.com/pkg/errors" @@ -10,6 +12,9 @@ import ( const ( // JSONRPCVersion is the only JSON-RPC protocol version supported. JSONRPCVersion = "2.0" + + // maxBatchSize is the maximum number of request per batch. + maxBatchSize = 100 ) // RawParams is just a slice of abstract values, used to represent parameters @@ -36,9 +41,16 @@ type Raw struct { ID int `json:"id"` } +// Request contains standard JSON-RPC 2.0 request and batch of +// requests: http://www.jsonrpc.org/specification. +// It's used in server to represent incoming queries. +type Request struct { + In *In + Batch Batch +} + // In represents a standard JSON-RPC 2.0 -// request: http://www.jsonrpc.org/specification#request_object. It's used in -// server to represent incoming queries. +// request: http://www.jsonrpc.org/specification#request_object. type In struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` @@ -46,28 +58,81 @@ type In struct { RawID json.RawMessage `json:"id,omitempty"` } -// NewIn creates a new Request struct. -func NewIn() *In { - return &In{ - JSONRPC: JSONRPCVersion, +// Batch represents a standard JSON-RPC 2.0 +// batch: https://www.jsonrpc.org/specification#batch. +type Batch []In + +// MarshalJSON implements json.Marshaler interface +func (r Request) MarshalJSON() ([]byte, error) { + if r.In != nil { + return json.Marshal(r.In) } + return json.Marshal(r.Batch) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (r *Request) UnmarshalJSON(data []byte) error { + var ( + in *In + batch Batch + ) + in = &In{} + err := json.Unmarshal(data, in) + if err == nil { + r.In = in + return nil + } + decoder := json.NewDecoder(bytes.NewReader(data)) + t, err := decoder.Token() // read `[` + if err != nil { + return err + } + if t != json.Delim('[') { + return fmt.Errorf("`[` expected, got %s", t) + } + count := 0 + for decoder.More() { + if count > maxBatchSize { + return fmt.Errorf("the number of requests in batch shouldn't exceed %d", maxBatchSize) + } + in = &In{} + decodeErr := decoder.Decode(in) + if decodeErr != nil { + return decodeErr + } + batch = append(batch, *in) + count++ + } + if len(batch) == 0 { + return errors.New("empty request") + } + r.Batch = batch + return nil } // DecodeData decodes the given reader into the the request // struct. -func (r *In) DecodeData(data io.ReadCloser) error { +func (r *Request) DecodeData(data io.ReadCloser) error { defer data.Close() - err := json.NewDecoder(data).Decode(r) + rawData := json.RawMessage{} + err := json.NewDecoder(data).Decode(&rawData) if err != nil { return errors.Errorf("error parsing JSON payload: %s", err) } + return r.UnmarshalJSON(rawData) +} - if r.JSONRPC != JSONRPCVersion { - return errors.Errorf("invalid version, expected 2.0 got: '%s'", r.JSONRPC) +// NewRequest creates a new Request struct. +func NewRequest() *Request { + return &Request{} +} + +// NewIn creates a new In struct. +func NewIn() *In { + return &In{ + JSONRPC: JSONRPCVersion, } - - return nil } // Params takes a slice of any type and attempts to bind diff --git a/pkg/rpc/response/types.go b/pkg/rpc/response/types.go index ba23c7677..63db81f31 100644 --- a/pkg/rpc/response/types.go +++ b/pkg/rpc/response/types.go @@ -38,6 +38,31 @@ type GetRawTx struct { Result *result.TransactionOutputRaw `json:"result"` } +// AbstractResult is an interface which represents either single JSON-RPC 2.0 response +// or batch JSON-RPC 2.0 response. +type AbstractResult interface { + RunForErrors(f func(jsonErr *Error)) +} + +// RunForErrors implements AbstractResult interface. +func (r Raw) RunForErrors(f func(jsonErr *Error)) { + if r.Error != nil { + f(r.Error) + } +} + +// RawBatch represents abstract JSON-RPC 2.0 batch-response. +type RawBatch []Raw + +// RunForErrors implements AbstractResult interface. +func (rb RawBatch) RunForErrors(f func(jsonErr *Error)) { + for _, r := range rb { + if r.Error != nil { + f(r.Error) + } + } +} + // Notification is a type used to represent wire format of events, they're // special in that they look like requests but they don't have IDs and their // "method" is actually an event name. diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 6073d36bc..6eabd408d 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -220,7 +220,7 @@ func (s *Server) Shutdown() error { } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { - req := request.NewIn() + req := request.NewRequest() if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { // Technically there is a race between this check and @@ -232,7 +232,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.subsLock.RUnlock() if numOfSubs >= maxSubscribers { s.writeHTTPErrorResponse( - req, + request.NewIn(), w, response.NewInternalServerError("websocket users limit reached", nil), ) @@ -243,7 +243,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.log.Info("websocket connection upgrade failed", zap.Error(err)) return } - resChan := make(chan response.Raw) + resChan := make(chan response.AbstractResult) // response.Raw or response.RawBatch subChan := make(chan *websocket.PreparedMessage, notificationBufSize) subscr := &subscriber{writer: subChan, ws: ws} s.subsLock.Lock() @@ -256,7 +256,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ if httpRequest.Method != "POST" { s.writeHTTPErrorResponse( - req, + request.NewIn(), w, response.NewInvalidParamsError( fmt.Sprintf("Invalid method '%s', please retry with 'POST'", httpRequest.Method), nil, @@ -267,7 +267,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ err := req.DecodeData(httpRequest.Body) if err != nil { - s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) + s.writeHTTPErrorResponse(request.NewIn(), w, response.NewParseError("Problem parsing JSON-RPC request body", err)) return } @@ -275,9 +275,23 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw { +func (s *Server) handleRequest(req *request.Request, sub *subscriber) response.AbstractResult { + if req.In != nil { + return s.handleIn(req.In, sub) + } + resp := make(response.RawBatch, len(req.Batch)) + for i, in := range req.Batch { + resp[i] = s.handleIn(&in, sub) + } + return resp +} + +func (s *Server) handleIn(req *request.In, sub *subscriber) response.Raw { var res interface{} var resErr *response.Error + if req.JSONRPC != request.JSONRPCVersion { + return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing JSON", fmt.Errorf("invalid version, expected 2.0 got: '%s'", req.JSONRPC))) + } reqParams, err := req.Params() if err != nil { @@ -303,7 +317,7 @@ func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Raw { return s.packResponseToRaw(req, res, resErr) } -func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw, subChan <-chan *websocket.PreparedMessage) { +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.AbstractResult, subChan <-chan *websocket.PreparedMessage) { pingTicker := time.NewTicker(wsPingPeriod) eventloop: for { @@ -350,21 +364,21 @@ drainloop: } } -func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw, subscr *subscriber) { +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.AbstractResult, subscr *subscriber) { ws.SetReadLimit(wsReadLimit) ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) requestloop: for { - req := new(request.In) + req := request.NewRequest() err := ws.ReadJSON(req) if err != nil { break } res := s.handleRequest(req, subscr) - if res.Error != nil { - s.logRequestError(req, res.Error) - } + res.RunForErrors(func(jsonErr *response.Error) { + s.logRequestError(req, jsonErr) + }) select { case <-s.shutdown: break requestloop @@ -1842,15 +1856,17 @@ func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *r } // logRequestError is a request error logger. -func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { +func (s *Server) logRequestError(r *request.Request, jsonErr *response.Error) { logFields := []zap.Field{ zap.Error(jsonErr.Cause), - zap.String("method", r.Method), } - params, err := r.Params() - if err == nil { - logFields = append(logFields, zap.Any("params", params)) + if r.In != nil { + logFields = append(logFields, zap.String("method", r.In.Method)) + params, err := r.In.Params() + if err == nil { + logFields = append(logFields, zap.Any("params", params)) + } } s.log.Error("Error encountered with rpc request", logFields...) @@ -1859,14 +1875,19 @@ func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { // writeHTTPErrorResponse writes an error response to the ResponseWriter. func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { resp := s.packResponseToRaw(r, nil, jsonErr) - s.writeHTTPServerResponse(r, w, resp) + s.writeHTTPServerResponse(&request.Request{In: r}, w, resp) } -func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { +func (s *Server) writeHTTPServerResponse(r *request.Request, w http.ResponseWriter, resp response.AbstractResult) { // Errors can happen in many places and we can only catch ALL of them here. - if resp.Error != nil { - s.logRequestError(r, resp.Error) - w.WriteHeader(resp.Error.HTTPCode) + resp.RunForErrors(func(jsonErr *response.Error) { + s.logRequestError(r, jsonErr) + }) + if r.In != nil { + resp := resp.(response.Raw) + if resp.Error != nil { + w.WriteHeader(resp.Error.HTTPCode) + } } w.Header().Set("Content-Type", "application/json; charset=utf-8") if s.config.EnableCORSWorkaround { @@ -1878,9 +1899,15 @@ func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, r err := encoder.Encode(resp) if err != nil { - s.log.Error("Error encountered while encoding response", - zap.String("err", err.Error()), - zap.String("method", r.Method)) + switch { + case r.In != nil: + s.log.Error("Error encountered while encoding response", + zap.String("err", err.Error()), + zap.String("method", r.In.Method)) + case r.Batch != nil: + s.log.Error("Error encountered while encoding batch response", + zap.String("err", err.Error())) + } } } diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 721c65b35..e1cae687d 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -968,31 +968,111 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] defer rpcSrv.Shutdown() e := &executor{chain: chain, httpSrv: httpSrv} - for method, cases := range rpcTestCases { - t.Run(method, func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` + t.Run("single request", func(t *testing.T) { + for method, cases := range rpcTestCases { + t.Run(method, func(t *testing.T) { + rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) - result := checkErrGetResult(t, body, tc.fail) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) + result := checkErrGetResult(t, body, tc.fail) + if tc.fail { + return + } + + expected, res := tc.getResultPair(e) + err := json.Unmarshal(result, res) + require.NoErrorf(t, err, "could not parse response: %s", result) + + if tc.check == nil { + assert.Equal(t, expected, res) + } else { + tc.check(t, e, res) + } + }) + } + }) + } + }) + t.Run("batch with single request", func(t *testing.T) { + for method, cases := range rpcTestCases { + if method == "sendrawtransaction" || method == "submitblock" { + continue // cannot send the same transaction twice + } + t.Run(method, func(t *testing.T) { + rpc := `[{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}]` + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) + result := checkErrGetBatchResult(t, body, tc.fail) + if tc.fail { + return + } + + expected, res := tc.getResultPair(e) + err := json.Unmarshal(result, res) + require.NoErrorf(t, err, "could not parse response: %s", result) + + if tc.check == nil { + assert.Equal(t, expected, res) + } else { + tc.check(t, e, res) + } + }) + } + }) + } + }) + + t.Run("batch with multiple requests", func(t *testing.T) { + for method, cases := range rpcTestCases { + if method == "sendrawtransaction" || method == "submitblock" { + continue // cannot send the same transaction twice + } + t.Run(method, func(t *testing.T) { + rpc := `{"jsonrpc": "2.0", "id": %d, "method": "%s", "params": %s},` + var resultRPC string + for i, tc := range cases { + resultRPC += fmt.Sprintf(rpc, i, method, tc.params) + } + resultRPC = `[` + resultRPC[:len(resultRPC)-1] + `]` + body := doRPCCall(resultRPC, httpSrv.URL, t) + var responses []response.Raw + err := json.Unmarshal(body, &responses) + require.Nil(t, err) + for i, tc := range cases { + var resp response.Raw + for _, r := range responses { + if bytes.Equal(r.ID, []byte(strconv.Itoa(i))) { + resp = r + break + } + } + if tc.fail { + require.NotNil(t, resp.Error) + assert.NotEqual(t, 0, resp.Error.Code) + assert.NotEqual(t, "", resp.Error.Message) + } else { + assert.Nil(t, resp.Error) + } if tc.fail { return } - expected, res := tc.getResultPair(e) - err := json.Unmarshal(result, res) - require.NoErrorf(t, err, "could not parse response: %s", result) + err := json.Unmarshal(resp.Result, res) + require.NoErrorf(t, err, "could not parse response: %s", resp.Result) if tc.check == nil { assert.Equal(t, expected, res) } else { tc.check(t, e, res) } - }) - } - }) - } + } + }) + } + }) t.Run("getproof", func(t *testing.T) { r, err := chain.GetStateRoot(210) @@ -1429,6 +1509,21 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe return resp.Result } +func checkErrGetBatchResult(t *testing.T, body []byte, expectingFail bool) json.RawMessage { + var resp []response.Raw + err := json.Unmarshal(body, &resp) + require.Nil(t, err) + require.Equal(t, 1, len(resp)) + if expectingFail { + require.NotNil(t, resp[0].Error) + assert.NotEqual(t, 0, resp[0].Error.Code) + assert.NotEqual(t, "", resp[0].Error.Message) + } else { + assert.Nil(t, resp[0].Error) + } + return resp[0].Result +} + func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte { dialer := websocket.Dialer{HandshakeTimeout: time.Second} url = "ws" + strings.TrimPrefix(url, "http")