diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index 2c144e3d6..24363f1dd 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -12,6 +12,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" + "github.com/nspcc-dev/neo-go/pkg/util" ) // WSClient is a websocket-enabled RPC client that can be used with appropriate @@ -239,30 +240,51 @@ func (c *WSClient) performUnsubscription(id string) error { } // SubscribeForNewBlocks adds subscription for new block events to this instance -// of client. -func (c *WSClient) SubscribeForNewBlocks() (string, error) { +// of client. It can filtered by primary consensus node index, nil value doesn't +// add any filters. +func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) { params := request.NewRawParams("block_added") + if primary != nil { + params.Values = append(params.Values, request.BlockFilter{Primary: *primary}) + } return c.performSubscription(params) } // SubscribeForNewTransactions adds subscription for new transaction events to -// this instance of client. -func (c *WSClient) SubscribeForNewTransactions() (string, error) { +// this instance of client. It can be filtered by sender and/or cosigner, nil +// value is treated as missing filter. +func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, cosigner *util.Uint160) (string, error) { params := request.NewRawParams("transaction_added") + if sender != nil || cosigner != nil { + params.Values = append(params.Values, request.TxFilter{Sender: sender, Cosigner: cosigner}) + } return c.performSubscription(params) } // SubscribeForExecutionNotifications adds subscription for notifications -// generated during transaction execution to this instance of client. -func (c *WSClient) SubscribeForExecutionNotifications() (string, error) { +// generated during transaction execution to this instance of client. It can be +// filtered by contract's hash (that emits notifications), nil value puts no such +// restrictions. +func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160) (string, error) { params := request.NewRawParams("notification_from_execution") + if contract != nil { + params.Values = append(params.Values, request.NotificationFilter{Contract: *contract}) + } return c.performSubscription(params) } // SubscribeForTransactionExecutions adds subscription for application execution -// results generated during transaction execution to this instance of client. -func (c *WSClient) SubscribeForTransactionExecutions() (string, error) { +// results generated during transaction execution to this instance of client. Can +// be filtered by state (HALT/FAULT) to check for successful or failing +// transactions, nil value means no filtering. +func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) { params := request.NewRawParams("transaction_executed") + if state != nil { + if *state != "HALT" && *state != "FAULT" { + return "", errors.New("bad state parameter") + } + params.Values = append(params.Values, request.ExecutionFilter{State: *state}) + } return c.performSubscription(params) } diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 708c18d45..494d31072 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -8,6 +8,8 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/request" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) @@ -21,10 +23,18 @@ func TestWSClientClose(t *testing.T) { func TestWSClientSubscription(t *testing.T) { var cases = map[string]func(*WSClient) (string, error){ - "blocks": (*WSClient).SubscribeForNewBlocks, - "transactions": (*WSClient).SubscribeForNewTransactions, - "notifications": (*WSClient).SubscribeForExecutionNotifications, - "executions": (*WSClient).SubscribeForTransactionExecutions, + "blocks": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForNewBlocks(nil) + }, + "transactions": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForNewTransactions(nil, nil) + }, + "notifications": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForExecutionNotifications(nil) + }, + "executions": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForTransactionExecutions(nil) + }, } t.Run("good", func(t *testing.T) { for name, f := range cases { @@ -145,3 +155,145 @@ func TestWSClientEvents(t *testing.T) { // Connection closed by server. require.False(t, ok) } + +func TestWSExecutionVMStateCheck(t *testing.T) { + // Will answer successfully if request slips through. + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + filter := "NONE" + _, err = wsc.SubscribeForTransactionExecutions(&filter) + require.Error(t, err) + wsc.Close() +} + +func TestWSFilteredSubscriptions(t *testing.T) { + var cases = []struct { + name string + clientCode func(*testing.T, *WSClient) + serverCode func(*testing.T, *request.Params) + }{ + {"blocks", + func(t *testing.T, wsc *WSClient) { + primary := 3 + _, err := wsc.SubscribeForNewBlocks(&primary) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.BlockFilterT, param.Type) + filt, ok := param.Value.(request.BlockFilter) + require.Equal(t, true, ok) + require.Equal(t, 3, filt.Primary) + }, + }, + {"transactions sender", + func(t *testing.T, wsc *WSClient) { + sender := util.Uint160{1, 2, 3, 4, 5} + _, err := wsc.SubscribeForNewTransactions(&sender, nil) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.TxFilterT, param.Type) + filt, ok := param.Value.(request.TxFilter) + require.Equal(t, true, ok) + require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Sender) + require.Nil(t, filt.Cosigner) + }, + }, + {"transactions cosigner", + func(t *testing.T, wsc *WSClient) { + cosigner := util.Uint160{0, 42} + _, err := wsc.SubscribeForNewTransactions(nil, &cosigner) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.TxFilterT, param.Type) + filt, ok := param.Value.(request.TxFilter) + require.Equal(t, true, ok) + require.Nil(t, filt.Sender) + require.Equal(t, util.Uint160{0, 42}, *filt.Cosigner) + }, + }, + {"transactions sender and cosigner", + func(t *testing.T, wsc *WSClient) { + sender := util.Uint160{1, 2, 3, 4, 5} + cosigner := util.Uint160{0, 42} + _, err := wsc.SubscribeForNewTransactions(&sender, &cosigner) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.TxFilterT, param.Type) + filt, ok := param.Value.(request.TxFilter) + require.Equal(t, true, ok) + require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Sender) + require.Equal(t, util.Uint160{0, 42}, *filt.Cosigner) + }, + }, + {"notifications", + func(t *testing.T, wsc *WSClient) { + contract := util.Uint160{1, 2, 3, 4, 5} + _, err := wsc.SubscribeForExecutionNotifications(&contract) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.NotificationFilterT, param.Type) + filt, ok := param.Value.(request.NotificationFilter) + require.Equal(t, true, ok) + require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, filt.Contract) + }, + }, + {"executions", + func(t *testing.T, wsc *WSClient) { + state := "FAULT" + _, err := wsc.SubscribeForTransactionExecutions(&state) + require.NoError(t, err) + }, + func(t *testing.T, p *request.Params) { + param, ok := p.Value(1) + require.Equal(t, true, ok) + require.Equal(t, request.ExecutionFilterT, param.Type) + filt, ok := param.Value.(request.ExecutionFilter) + require.Equal(t, true, ok) + require.Equal(t, "FAULT", filt.State) + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + req := request.In{} + err = ws.ReadJSON(&req) + require.NoError(t, err) + params, err := req.Params() + require.NoError(t, err) + c.serverCode(t, params) + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(`{"jsonrpc": "2.0", "id": 1, "result": "0"}`)) + require.NoError(t, err) + ws.Close() + } + })) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + c.clientCode(t, wsc) + wsc.Close() + }) + } +} diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 55052990e..f917a33ee 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -29,6 +29,29 @@ type ( Type smartcontract.ParamType `json:"type"` Value Param `json:"value"` } + // BlockFilter is a wrapper structure for block event filter. The only + // allowed filter is primary index. + BlockFilter struct { + Primary int `json:"primary"` + } + // TxFilter is a wrapper structure for transaction event filter. It + // allows to filter transactions by senders and cosigners. + TxFilter struct { + Sender *util.Uint160 `json:"sender,omitempty"` + Cosigner *util.Uint160 `json:"cosigner,omitempty"` + } + // NotificationFilter is a wrapper structure representing filter used for + // notifications generated during transaction execution. Notifications can + // only be filtered by contract hash. + NotificationFilter struct { + Contract util.Uint160 `json:"contract"` + } + // ExecutionFilter is a wrapper structure used for transaction execution + // events. It allows to choose failing or successful transactions based + // on their VM state. + ExecutionFilter struct { + State string `json:"state"` + } ) // These are parameter types accepted by RPC server. @@ -38,6 +61,10 @@ const ( NumberT ArrayT FuncParamT + BlockFilterT + TxFilterT + NotificationFilterT + ExecutionFilterT ) func (p Param) String() string { @@ -130,38 +157,46 @@ func (p Param) GetBytesHex() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler interface. func (p *Param) UnmarshalJSON(data []byte) error { var s string - if err := json.Unmarshal(data, &s); err == nil { - p.Type = StringT - p.Value = s - - return nil - } - var num float64 - if err := json.Unmarshal(data, &num); err == nil { - p.Type = NumberT - p.Value = int(num) - - return nil + // To unmarshal correctly we need to pass pointers into the decoder. + var attempts = [...]Param{ + {NumberT, &num}, + {StringT, &s}, + {FuncParamT, &FuncParam{}}, + {BlockFilterT, &BlockFilter{}}, + {TxFilterT, &TxFilter{}}, + {NotificationFilterT, &NotificationFilter{}}, + {ExecutionFilterT, &ExecutionFilter{}}, + {ArrayT, &[]Param{}}, } - r := bytes.NewReader(data) - jd := json.NewDecoder(r) - jd.DisallowUnknownFields() - var fp FuncParam - if err := jd.Decode(&fp); err == nil { - p.Type = FuncParamT - p.Value = fp - - return nil - } - - var ps []Param - if err := json.Unmarshal(data, &ps); err == nil { - p.Type = ArrayT - p.Value = ps - - return nil + for _, cur := range attempts { + r := bytes.NewReader(data) + jd := json.NewDecoder(r) + jd.DisallowUnknownFields() + if err := jd.Decode(cur.Value); err == nil { + p.Type = cur.Type + // But we need to store actual values, not pointers. + switch val := cur.Value.(type) { + case *float64: + p.Value = int(*val) + case *string: + p.Value = *val + case *FuncParam: + p.Value = *val + case *BlockFilter: + p.Value = *val + case *TxFilter: + p.Value = *val + case *NotificationFilter: + p.Value = *val + case *ExecutionFilter: + p.Value = *val + case *[]Param: + p.Value = *val + } + return nil + } } return errors.New("unknown type") diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index 26d29bab1..032173446 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -13,7 +13,15 @@ import ( ) func TestParam_UnmarshalJSON(t *testing.T) { - msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}]]` + msg := `["str1", 123, ["str2", 3], [{"type": "String", "value": "jajaja"}], + {"primary": 1}, + {"sender": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, + {"cosigner": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, + {"sender": "f84d6a337fbc3d3a201d41da99e86b479e7a2554", "cosigner": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, + {"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, + {"state": "HALT"}]` + contr, err := util.Uint160DecodeStringLE("f84d6a337fbc3d3a201d41da99e86b479e7a2554") + require.NoError(t, err) expected := Params{ { Type: StringT, @@ -51,6 +59,30 @@ func TestParam_UnmarshalJSON(t *testing.T) { }, }, }, + { + Type: BlockFilterT, + Value: BlockFilter{Primary: 1}, + }, + { + Type: TxFilterT, + Value: TxFilter{Sender: &contr}, + }, + { + Type: TxFilterT, + Value: TxFilter{Cosigner: &contr}, + }, + { + Type: TxFilterT, + Value: TxFilter{Sender: &contr, Cosigner: &contr}, + }, + { + Type: NotificationFilterT, + Value: NotificationFilter{Contract: contr}, + }, + { + Type: ExecutionFilterT, + Value: ExecutionFilter{State: "HALT"}, + }, } var ps Params