diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 66844cd74..e13b3326b 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -1426,12 +1426,12 @@ func initTestServer(t *testing.T, resp string) *httptest.Server { ws.Close() return } - r := request.NewIn() + r := request.NewRequest() err := r.DecodeData(req.Body) if err != nil { t.Fatalf("Cannot decode request body: %s", req.Body) } - requestHandler(t, r, w, resp) + requestHandler(t, r.In, w, resp) })) return srv @@ -1480,13 +1480,13 @@ func TestCalculateValidUntilBlock(t *testing.T) { getValidatorsCalled int ) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - r := request.NewIn() + r := request.NewRequest() err := r.DecodeData(req.Body) if err != nil { t.Fatalf("Cannot decode request body: %s", req.Body) } var response string - switch r.Method { + switch r.In.Method { case "getblockcount": getBlockCountCalled++ response = `{"jsonrpc":"2.0","id":1,"result":50}` @@ -1494,7 +1494,7 @@ func TestCalculateValidUntilBlock(t *testing.T) { getValidatorsCalled++ response = `{"id":1,"jsonrpc":"2.0","result":[{"publickey":"02b3622bf4017bdfe317c58aed5f4c753f206b7db896046fa7d774bbc4bf7f8dc2","votes":"0","active":true},{"publickey":"02103a7f7dd016558597f7960d27c516a4394fd968b9e65155eb4b013e4040406e","votes":"0","active":true},{"publickey":"03d90c07df63e690ce77912e10ab51acc944b66860237b608c4f8f8309e71ee699","votes":"0","active":true},{"publickey":"02a7bc55fe8684e0119768d104ba30795bdcc86619e864add26156723ed185cd62","votes":"0","active":true}]}` } - requestHandler(t, r, w, response) + requestHandler(t, r.In, w, response) })) defer srv.Close() @@ -1522,13 +1522,13 @@ func TestCalculateValidUntilBlock(t *testing.T) { func TestGetNetwork(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - r := request.NewIn() + r := request.NewRequest() err := r.DecodeData(req.Body) if err != nil { t.Fatalf("Cannot decode request body: %s", req.Body) } // request handler already have `getversion` response wrapper - requestHandler(t, r, w, "") + requestHandler(t, r.In, w, "") })) defer srv.Close() endpoint := srv.URL diff --git a/pkg/rpc/request/types.go b/pkg/rpc/request/types.go index 6882385f1..8da4f31b5 100644 --- a/pkg/rpc/request/types.go +++ b/pkg/rpc/request/types.go @@ -1,7 +1,9 @@ package request import ( + "bytes" "encoding/json" + "errors" "fmt" "io" ) @@ -9,6 +11,9 @@ import ( const ( // JSONRPCVersion is the only JSON-RPC protocol version supported. JSONRPCVersion = "2.0" + + // maxBatchSize is the maximum number of request per batch. + maxBatchSize = 100 ) // RawParams is just a slice of abstract values, used to represent parameters @@ -35,9 +40,16 @@ type Raw struct { ID int `json:"id"` } +// Request contains standard JSON-RPC 2.0 request and batch of +// requests: http://www.jsonrpc.org/specification. +// It's used in server to represent incoming queries. +type Request struct { + In *In + Batch Batch +} + // In represents a standard JSON-RPC 2.0 -// request: http://www.jsonrpc.org/specification#request_object. It's used in -// server to represent incoming queries. +// request: http://www.jsonrpc.org/specification#request_object. type In struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` @@ -45,28 +57,82 @@ type In struct { RawID json.RawMessage `json:"id,omitempty"` } -// NewIn creates a new Request struct. -func NewIn() *In { - return &In{ - JSONRPC: JSONRPCVersion, +// Batch represents a standard JSON-RPC 2.0 +// batch: https://www.jsonrpc.org/specification#batch. +type Batch []In + +// MarshalJSON implements json.Marshaler interface +func (r Request) MarshalJSON() ([]byte, error) { + if r.In != nil { + return json.Marshal(r.In) } + return json.Marshal(r.Batch) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (r *Request) UnmarshalJSON(data []byte) error { + var ( + in *In + batch Batch + ) + in = &In{} + err := json.Unmarshal(data, in) + if err == nil { + r.In = in + return nil + } + decoder := json.NewDecoder(bytes.NewReader(data)) + t, err := decoder.Token() // read `[` + if err != nil { + return err + } + if t != json.Delim('[') { + return fmt.Errorf("`[` expected, got %s", t) + } + count := 0 + for decoder.More() { + if count > maxBatchSize { + return fmt.Errorf("the number of requests in batch shouldn't exceed %d", maxBatchSize) + } + in = &In{} + decodeErr := decoder.Decode(in) + if decodeErr != nil { + return decodeErr + } + batch = append(batch, *in) + count++ + } + if len(batch) == 0 { + return errors.New("empty request") + } + r.Batch = batch + return nil } // DecodeData decodes the given reader into the the request // struct. -func (r *In) DecodeData(data io.ReadCloser) error { +func (r *Request) DecodeData(data io.ReadCloser) error { defer data.Close() - err := json.NewDecoder(data).Decode(r) + rawData := json.RawMessage{} + err := json.NewDecoder(data).Decode(&rawData) if err != nil { return fmt.Errorf("error parsing JSON payload: %w", err) } - if r.JSONRPC != JSONRPCVersion { - return fmt.Errorf("invalid version, expected 2.0 got: '%s'", r.JSONRPC) - } + return r.UnmarshalJSON(rawData) +} - return nil +// NewRequest creates a new Request struct. +func NewRequest() *Request { + return &Request{} +} + +// NewIn creates a new In struct. +func NewIn() *In { + return &In{ + JSONRPC: JSONRPCVersion, + } } // Params takes a slice of any type and attempts to bind diff --git a/pkg/rpc/response/types.go b/pkg/rpc/response/types.go index 1cf09bddc..a55bab21f 100644 --- a/pkg/rpc/response/types.go +++ b/pkg/rpc/response/types.go @@ -24,6 +24,12 @@ type Raw struct { Result json.RawMessage `json:"result,omitempty"` } +// AbstractResult is an interface which represents either single JSON-RPC 2.0 response +// or batch JSON-RPC 2.0 response. +type AbstractResult interface { + RunForErrors(f func(jsonErr *Error)) +} + // Abstract represents abstract JSON-RPC 2.0 response, it differs from Raw in // that Result field is an interface here. type Abstract struct { @@ -31,6 +37,25 @@ type Abstract struct { Result interface{} `json:"result,omitempty"` } +// RunForErrors implements AbstractResult interface. +func (a Abstract) RunForErrors(f func(jsonErr *Error)) { + if a.Error != nil { + f(a.Error) + } +} + +// AbstractBatch represents abstract JSON-RPC 2.0 batch-response. +type AbstractBatch []Abstract + +// RunForErrors implements AbstractResult interface. +func (ab AbstractBatch) RunForErrors(f func(jsonErr *Error)) { + for _, a := range ab { + if a.Error != nil { + f(a.Error) + } + } +} + // Notification is a type used to represent wire format of events, they're // special in that they look like requests but they don't have IDs and their // "method" is actually an event name. diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 92ef4b7d9..718851101 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -225,7 +225,7 @@ func (s *Server) Shutdown() error { } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { - req := request.NewIn() + req := request.NewRequest() if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { // Technically there is a race between this check and @@ -237,7 +237,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.subsLock.RUnlock() if numOfSubs >= maxSubscribers { s.writeHTTPErrorResponse( - req, + request.NewIn(), w, response.NewInternalServerError("websocket users limit reached", nil), ) @@ -248,7 +248,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.log.Info("websocket connection upgrade failed", zap.Error(err)) return } - resChan := make(chan response.Abstract) + resChan := make(chan response.AbstractResult) // response.Abstract or response.AbstractBatch subChan := make(chan *websocket.PreparedMessage, notificationBufSize) subscr := &subscriber{writer: subChan, ws: ws} s.subsLock.Lock() @@ -261,7 +261,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ if httpRequest.Method != "POST" { s.writeHTTPErrorResponse( - req, + request.NewIn(), w, response.NewInvalidParamsError( fmt.Sprintf("Invalid method '%s', please retry with 'POST'", httpRequest.Method), nil, @@ -272,7 +272,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ err := req.DecodeData(httpRequest.Body) if err != nil { - s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) + s.writeHTTPErrorResponse(request.NewIn(), w, response.NewParseError("Problem parsing JSON-RPC request body", err)) return } @@ -280,9 +280,23 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Abstract { +func (s *Server) handleRequest(req *request.Request, sub *subscriber) response.AbstractResult { + if req.In != nil { + return s.handleIn(req.In, sub) + } + resp := make(response.AbstractBatch, len(req.Batch)) + for i, in := range req.Batch { + resp[i] = s.handleIn(&in, sub) + } + return resp +} + +func (s *Server) handleIn(req *request.In, sub *subscriber) response.Abstract { var res interface{} var resErr *response.Error + if req.JSONRPC != request.JSONRPCVersion { + return s.packResponse(req, nil, response.NewInvalidParamsError("Problem parsing JSON", fmt.Errorf("invalid version, expected 2.0 got: '%s'", req.JSONRPC))) + } reqParams, err := req.Params() if err != nil { @@ -308,7 +322,7 @@ func (s *Server) handleRequest(req *request.In, sub *subscriber) response.Abstra return s.packResponse(req, res, resErr) } -func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Abstract, subChan <-chan *websocket.PreparedMessage) { +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.AbstractResult, subChan <-chan *websocket.PreparedMessage) { pingTicker := time.NewTicker(wsPingPeriod) eventloop: for { @@ -355,21 +369,21 @@ drainloop: } } -func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Abstract, subscr *subscriber) { +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.AbstractResult, subscr *subscriber) { ws.SetReadLimit(wsReadLimit) ws.SetReadDeadline(time.Now().Add(wsPongLimit)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) requestloop: for { - req := new(request.In) + req := request.NewRequest() err := ws.ReadJSON(req) if err != nil { break } res := s.handleRequest(req, subscr) - if res.Error != nil { - s.logRequestError(req, res.Error) - } + res.RunForErrors(func(jsonErr *response.Error) { + s.logRequestError(req, jsonErr) + }) select { case <-s.shutdown: break requestloop @@ -1388,15 +1402,17 @@ func (s *Server) packResponse(r *request.In, result interface{}, respErr *respon } // logRequestError is a request error logger. -func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { +func (s *Server) logRequestError(r *request.Request, jsonErr *response.Error) { logFields := []zap.Field{ zap.Error(jsonErr.Cause), - zap.String("method", r.Method), } - params, err := r.Params() - if err == nil { - logFields = append(logFields, zap.Any("params", params)) + if r.In != nil { + logFields = append(logFields, zap.String("method", r.In.Method)) + params, err := r.In.Params() + if err == nil { + logFields = append(logFields, zap.Any("params", params)) + } } s.log.Error("Error encountered with rpc request", logFields...) @@ -1405,14 +1421,19 @@ func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { // writeHTTPErrorResponse writes an error response to the ResponseWriter. func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { resp := s.packResponse(r, nil, jsonErr) - s.writeHTTPServerResponse(r, w, resp) + s.writeHTTPServerResponse(&request.Request{In: r}, w, resp) } -func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Abstract) { +func (s *Server) writeHTTPServerResponse(r *request.Request, w http.ResponseWriter, resp response.AbstractResult) { // Errors can happen in many places and we can only catch ALL of them here. - if resp.Error != nil { - s.logRequestError(r, resp.Error) - w.WriteHeader(resp.Error.HTTPCode) + resp.RunForErrors(func(jsonErr *response.Error) { + s.logRequestError(r, jsonErr) + }) + if r.In != nil { + resp := resp.(response.Abstract) + if resp.Error != nil { + w.WriteHeader(resp.Error.HTTPCode) + } } w.Header().Set("Content-Type", "application/json; charset=utf-8") if s.config.EnableCORSWorkaround { @@ -1424,9 +1445,15 @@ func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, r err := encoder.Encode(resp) if err != nil { - s.log.Error("Error encountered while encoding response", - zap.String("err", err.Error()), - zap.String("method", r.Method)) + switch { + case r.In != nil: + s.log.Error("Error encountered while encoding response", + zap.String("err", err.Error()), + zap.String("method", r.In.Method)) + case r.Batch != nil: + s.log.Error("Error encountered while encoding batch response", + zap.String("err", err.Error())) + } } } diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 66f5a02ba..318cec2e4 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -778,31 +778,112 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] defer rpcSrv.Shutdown() e := &executor{chain: chain, httpSrv: httpSrv} - for method, cases := range rpcTestCases { - t.Run(method, func(t *testing.T) { - rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` + t.Run("single request", func(t *testing.T) { + for method, cases := range rpcTestCases { + t.Run(method, func(t *testing.T) { + rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) - result := checkErrGetResult(t, body, tc.fail) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) + result := checkErrGetResult(t, body, tc.fail) + if tc.fail { + return + } + + expected, res := tc.getResultPair(e) + err := json.Unmarshal(result, res) + require.NoErrorf(t, err, "could not parse response: %s", result) + + if tc.check == nil { + assert.Equal(t, expected, res) + } else { + tc.check(t, e, res) + } + }) + } + }) + } + }) + t.Run("batch with single request", func(t *testing.T) { + for method, cases := range rpcTestCases { + if method == "sendrawtransaction" { + continue // cannot send the same transaction twice + } + t.Run(method, func(t *testing.T) { + rpc := `[{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}]` + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) + result := checkErrGetBatchResult(t, body, tc.fail) + if tc.fail { + return + } + + expected, res := tc.getResultPair(e) + err := json.Unmarshal(result, res) + require.NoErrorf(t, err, "could not parse response: %s", result) + + if tc.check == nil { + assert.Equal(t, expected, res) + } else { + tc.check(t, e, res) + } + }) + } + }) + } + }) + + t.Run("batch with multiple requests", func(t *testing.T) { + for method, cases := range rpcTestCases { + if method == "sendrawtransaction" { + continue // cannot send the same transaction twice + } + t.Run(method, func(t *testing.T) { + rpc := `{"jsonrpc": "2.0", "id": %d, "method": "%s", "params": %s},` + var resultRPC string + for i, tc := range cases { + resultRPC += fmt.Sprintf(rpc, i, method, tc.params) + } + resultRPC = `[` + resultRPC[:len(resultRPC)-1] + `]` + body := doRPCCall(resultRPC, httpSrv.URL, t) + var responses []response.Raw + err := json.Unmarshal(body, &responses) + require.Nil(t, err) + for i, tc := range cases { + var resp response.Raw + for _, r := range responses { + if bytes.Equal(r.ID, []byte(strconv.Itoa(i))) { + resp = r + break + } + } + if tc.fail { + require.NotNil(t, resp.Error) + assert.NotEqual(t, 0, resp.Error.Code) + assert.NotEqual(t, "", resp.Error.Message) + } else { + assert.Nil(t, resp.Error) + } if tc.fail { return } - expected, res := tc.getResultPair(e) - err := json.Unmarshal(result, res) - require.NoErrorf(t, err, "could not parse response: %s", result) + err := json.Unmarshal(resp.Result, res) + require.NoErrorf(t, err, "could not parse response: %s", resp.Result) if tc.check == nil { assert.Equal(t, expected, res) } else { tc.check(t, e, res) } - }) - } - }) - } + } + + }) + } + }) t.Run("getapplicationlog for block", func(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getapplicationlog", "params": ["%s"]}` @@ -1075,6 +1156,21 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe return resp.Result } +func checkErrGetBatchResult(t *testing.T, body []byte, expectingFail bool) json.RawMessage { + var resp []response.Raw + err := json.Unmarshal(body, &resp) + require.Nil(t, err) + require.Equal(t, 1, len(resp)) + if expectingFail { + require.NotNil(t, resp[0].Error) + assert.NotEqual(t, 0, resp[0].Error.Code) + assert.NotEqual(t, "", resp[0].Error.Message) + } else { + assert.Nil(t, resp[0].Error) + } + return resp[0].Result +} + func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte { dialer := websocket.Dialer{HandshakeTimeout: time.Second} url = "ws" + strings.TrimPrefix(url, "http")