diff --git a/docs/notifications.md b/docs/notifications.md index 119130ed1..68889780b 100644 --- a/docs/notifications.md +++ b/docs/notifications.md @@ -67,11 +67,11 @@ omitted if empty). Recognized stream names: * `block_added` - Filter: `primary` as an integer with primary (speaker) node index from - ConsensusData and/or `since` field as an integer value with block - index starting from which new block notifications will be received and/or - `till` field as an integer values containing block index till which new - block notifications will be received. + Filter: `primary` as an integer with a valid range of 0-255 with + primary (speaker) node index from ConsensusData and/or `since` field as + an integer value with block index starting from which new block + notifications will be received and/or `till` field as an integer values + containing block index till which new block notifications will be received. * `header_of_added_block` Filter: `primary` as an integer with primary (speaker) node index from ConsensusData and/or `since` field as an integer value with header @@ -85,7 +85,8 @@ Recognized stream names: * `notification_from_execution` Filter: `contract` field containing a string with hex-encoded Uint160 (LE representation) and/or `name` field containing a string with execution - notification name. + notification name which should be a valid UTF-8 string not longer than + 32 bytes. * `transaction_executed` Filter: `state` field containing `HALT` or `FAULT` string for successful and failed executions respectively and/or `container` field containing diff --git a/pkg/neorpc/filters.go b/pkg/neorpc/filters.go index 255d94e86..f05cf9343 100644 --- a/pkg/neorpc/filters.go +++ b/pkg/neorpc/filters.go @@ -1,8 +1,13 @@ package neorpc import ( + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" ) type ( @@ -11,7 +16,7 @@ type ( // since/till the specified index inclusively). nil value treated as missing // filter. BlockFilter struct { - Primary *int `json:"primary,omitempty"` + Primary *byte `json:"primary,omitempty"` Since *uint32 `json:"since,omitempty"` Till *uint32 `json:"till,omitempty"` } @@ -49,6 +54,16 @@ type ( } ) +// SubscriptionFilter is an interface for all subscription filters. +type SubscriptionFilter interface { + // IsValid checks whether the filter is valid and returns + // a specific [ErrInvalidSubscriptionFilter] error if not. + IsValid() error +} + +// ErrInvalidSubscriptionFilter is returned when the subscription filter is invalid. +var ErrInvalidSubscriptionFilter = errors.New("invalid subscription filter") + // Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly. func (f *BlockFilter) Copy() *BlockFilter { if f == nil { @@ -56,7 +71,7 @@ func (f *BlockFilter) Copy() *BlockFilter { } var res = new(BlockFilter) if f.Primary != nil { - res.Primary = new(int) + res.Primary = new(byte) *res.Primary = *f.Primary } if f.Since != nil { @@ -70,6 +85,11 @@ func (f *BlockFilter) Copy() *BlockFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f BlockFilter) IsValid() error { + return nil +} + // Copy creates a deep copy of the TxFilter. It handles nil TxFilter correctly. func (f *TxFilter) Copy() *TxFilter { if f == nil { @@ -87,6 +107,11 @@ func (f *TxFilter) Copy() *TxFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f TxFilter) IsValid() error { + return nil +} + // Copy creates a deep copy of the NotificationFilter. It handles nil NotificationFilter correctly. func (f *NotificationFilter) Copy() *NotificationFilter { if f == nil { @@ -104,6 +129,14 @@ func (f *NotificationFilter) Copy() *NotificationFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f NotificationFilter) IsValid() error { + if f.Name != nil && len(*f.Name) > runtime.MaxEventNameLen { + return fmt.Errorf("%w: NotificationFilter name parameter must be less than %d", ErrInvalidSubscriptionFilter, runtime.MaxEventNameLen) + } + return nil +} + // Copy creates a deep copy of the ExecutionFilter. It handles nil ExecutionFilter correctly. func (f *ExecutionFilter) Copy() *ExecutionFilter { if f == nil { @@ -121,6 +154,17 @@ func (f *ExecutionFilter) Copy() *ExecutionFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f ExecutionFilter) IsValid() error { + if f.State != nil { + if *f.State != vmstate.Halt.String() && *f.State != vmstate.Fault.String() { + return fmt.Errorf("%w: ExecutionFilter state parameter must be either %s or %s", ErrInvalidSubscriptionFilter, vmstate.Halt, vmstate.Fault) + } + } + + return nil +} + // Copy creates a deep copy of the NotaryRequestFilter. It handles nil NotaryRequestFilter correctly. func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter { if f == nil { @@ -141,3 +185,8 @@ func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter { } return res } + +// IsValid implements SubscriptionFilter interface. +func (f NotaryRequestFilter) IsValid() error { + return nil +} diff --git a/pkg/neorpc/filters_test.go b/pkg/neorpc/filters_test.go index 8f725a588..e38591c5d 100644 --- a/pkg/neorpc/filters_test.go +++ b/pkg/neorpc/filters_test.go @@ -16,12 +16,12 @@ func TestBlockFilterCopy(t *testing.T) { tf = bf.Copy() require.Equal(t, bf, tf) - bf.Primary = new(int) + bf.Primary = new(byte) *bf.Primary = 42 tf = bf.Copy() require.Equal(t, bf, tf) - *bf.Primary = 100500 + *bf.Primary = 100 require.NotEqual(t, bf, tf) bf.Since = new(uint32) diff --git a/pkg/neorpc/rpcevent/filter.go b/pkg/neorpc/rpcevent/filter.go index f2be233ad..9b0bd0c69 100644 --- a/pkg/neorpc/rpcevent/filter.go +++ b/pkg/neorpc/rpcevent/filter.go @@ -13,7 +13,7 @@ type ( // filter notifications. Comparator interface { EventID() neorpc.EventID - Filter() any + Filter() neorpc.SubscriptionFilter } // Container is an interface required from notification event to be able to // pass filter. @@ -42,7 +42,7 @@ func Matches(f Comparator, r Container) bool { } else { b = &r.EventPayload().(*block.Block).Header } - primaryOk := filt.Primary == nil || *filt.Primary == int(b.PrimaryIndex) + primaryOk := filt.Primary == nil || *filt.Primary == b.PrimaryIndex sinceOk := filt.Since == nil || *filt.Since <= b.Index tillOk := filt.Till == nil || b.Index <= *filt.Till return primaryOk && sinceOk && tillOk diff --git a/pkg/neorpc/rpcevent/filter_test.go b/pkg/neorpc/rpcevent/filter_test.go index 6d8d33aad..70cfa9225 100644 --- a/pkg/neorpc/rpcevent/filter_test.go +++ b/pkg/neorpc/rpcevent/filter_test.go @@ -18,7 +18,7 @@ import ( type ( testComparator struct { id neorpc.EventID - filter any + filter neorpc.SubscriptionFilter } testContainer struct { id neorpc.EventID @@ -29,7 +29,7 @@ type ( func (c testComparator) EventID() neorpc.EventID { return c.id } -func (c testComparator) Filter() any { +func (c testComparator) Filter() neorpc.SubscriptionFilter { return c.filter } func (c testContainer) EventID() neorpc.EventID { @@ -40,8 +40,8 @@ func (c testContainer) EventPayload() any { } func TestMatches(t *testing.T) { - primary := 1 - badPrimary := 2 + primary := byte(1) + badPrimary := byte(2) index := uint32(5) badHigherIndex := uint32(6) badLowerIndex := index - 1 diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 962b26455..9674766da 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -127,7 +127,7 @@ func (r *blockReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *blockReceiver) Filter() any { +func (r *blockReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -174,7 +174,7 @@ func (r *headerOfAddedBlockReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *headerOfAddedBlockReceiver) Filter() any { +func (r *headerOfAddedBlockReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -220,7 +220,7 @@ func (r *txReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *txReceiver) Filter() any { +func (r *txReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -267,7 +267,7 @@ func (r *executionNotificationReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *executionNotificationReceiver) Filter() any { +func (r *executionNotificationReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -314,7 +314,7 @@ func (r *executionReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *executionReceiver) Filter() any { +func (r *executionReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -361,7 +361,7 @@ func (r *notaryRequestReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *notaryRequestReceiver) Filter() any { +func (r *notaryRequestReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -766,6 +766,11 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { func (c *WSClient) performSubscription(params []any, rcvr notificationReceiver) (string, error) { var resp string + if flt := rcvr.Filter(); flt != nil { + if err := flt.IsValid(); err != nil { + return "", err + } + } if err := c.performRequest("subscribe", params, &resp); err != nil { return "", err } @@ -872,11 +877,6 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s } params := []any{"transaction_executed"} if flt != nil { - if flt.State != nil { - if *flt.State != "HALT" && *flt.State != "FAULT" { - return "", errors.New("bad state parameter") - } - } flt = flt.Copy() params = append(params, *flt) } diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index 8247901ac..693783e79 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -369,7 +369,20 @@ func TestWSExecutionVMStateCheck(t *testing.T) { require.NoError(t, wsc.Init()) filter := "NONE" _, err = wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &filter}, make(chan *state.AppExecResult)) - require.Error(t, err) + require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter) + wsc.Close() +} + +func TestWSExecutionNotificationNameCheck(t *testing.T) { + // Will answer successfully if request slips through. + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) + require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + require.NoError(t, wsc.Init()) + filter := "notification_from_execution_with_long_name" + _, err = wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &filter}, make(chan *state.ContainedNotificationEvent)) + require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter) wsc.Close() } @@ -381,7 +394,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { }{ {"block header primary", func(t *testing.T, wsc *WSClient) { - primary := 3 + primary := byte(3) _, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Header)) require.NoError(t, err) }, @@ -389,7 +402,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, 3, *filt.Primary) + require.Equal(t, byte(3), *filt.Primary) require.Equal(t, (*uint32)(nil), filt.Since) require.Equal(t, (*uint32)(nil), filt.Till) }, @@ -404,7 +417,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, (*int)(nil), filt.Primary) + require.Equal(t, (*byte)(nil), filt.Primary) require.Equal(t, uint32(3), *filt.Since) require.Equal(t, (*uint32)(nil), filt.Till) }, @@ -419,7 +432,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, (*int)(nil), filt.Primary) + require.Equal(t, (*byte)(nil), filt.Primary) require.Equal(t, (*uint32)(nil), filt.Since) require.Equal(t, (uint32)(3), *filt.Till) }, @@ -428,7 +441,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { var ( since uint32 = 3 - primary = 2 + primary = byte(2) till uint32 = 5 ) _, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{ @@ -442,14 +455,14 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, 2, *filt.Primary) + require.Equal(t, byte(2), *filt.Primary) require.Equal(t, uint32(3), *filt.Since) require.Equal(t, uint32(5), *filt.Till) }, }, {"blocks primary", func(t *testing.T, wsc *WSClient) { - primary := 3 + primary := byte(3) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block)) require.NoError(t, err) }, @@ -457,7 +470,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, 3, *filt.Primary) + require.Equal(t, byte(3), *filt.Primary) require.Equal(t, (*uint32)(nil), filt.Since) require.Equal(t, (*uint32)(nil), filt.Till) }, @@ -472,7 +485,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, (*int)(nil), filt.Primary) + require.Equal(t, (*byte)(nil), filt.Primary) require.Equal(t, uint32(3), *filt.Since) require.Equal(t, (*uint32)(nil), filt.Till) }, @@ -487,7 +500,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, (*int)(nil), filt.Primary) + require.Equal(t, (*byte)(nil), filt.Primary) require.Equal(t, (*uint32)(nil), filt.Since) require.Equal(t, (uint32)(3), *filt.Till) }, @@ -496,7 +509,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { var ( since uint32 = 3 - primary = 2 + primary = byte(2) till uint32 = 5 ) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{ @@ -510,7 +523,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { param := p.Value(1) filt := new(neorpc.BlockFilter) require.NoError(t, json.Unmarshal(param.RawMessage, filt)) - require.Equal(t, 2, *filt.Primary) + require.Equal(t, byte(2), *filt.Primary) require.Equal(t, uint32(3), *filt.Since) require.Equal(t, uint32(5), *filt.Till) }, diff --git a/pkg/services/rpcsrv/client_test.go b/pkg/services/rpcsrv/client_test.go index b77b44b0f..8cdc51689 100644 --- a/pkg/services/rpcsrv/client_test.go +++ b/pkg/services/rpcsrv/client_test.go @@ -2113,9 +2113,9 @@ func TestWSClient_SubscriptionsCompat(t *testing.T) { blocks := getTestBlocks(t) bCount := uint32(0) - getData := func(t *testing.T) (*block.Block, int, util.Uint160, string, string) { + getData := func(t *testing.T) (*block.Block, byte, util.Uint160, string, string) { b1 := blocks[bCount] - primary := int(b1.PrimaryIndex) + primary := b1.PrimaryIndex tx := b1.Transactions[0] sender := tx.Sender() ntfName := "Transfer" diff --git a/pkg/services/rpcsrv/server.go b/pkg/services/rpcsrv/server.go index bb30f566d..5c4e2e9f6 100644 --- a/pkg/services/rpcsrv/server.go +++ b/pkg/services/rpcsrv/server.go @@ -2729,7 +2729,7 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, "P2PSigExtensions are disabled") } // Optional filter. - var filter any + var filter neorpc.SubscriptionFilter if p := reqParams.Value(1); p != nil { param := *p jd := json.NewDecoder(bytes.NewReader(param.RawMessage)) @@ -2754,16 +2754,18 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor case neorpc.ExecutionEventID: flt := new(neorpc.ExecutionFilter) err = jd.Decode(flt) - if err == nil && (flt.State == nil || (*flt.State == "HALT" || *flt.State == "FAULT")) { - filter = *flt - } else if err == nil { - err = errors.New("invalid state") - } + filter = *flt } if err != nil { return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error()) } } + if filter != nil { + err = filter.IsValid() + if err != nil { + return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error()) + } + } s.subsLock.Lock() var id int diff --git a/pkg/services/rpcsrv/subscription.go b/pkg/services/rpcsrv/subscription.go index ad271cae4..9e3b05893 100644 --- a/pkg/services/rpcsrv/subscription.go +++ b/pkg/services/rpcsrv/subscription.go @@ -28,7 +28,7 @@ type ( // feed stores subscriber's desired event ID with filter. feed struct { event neorpc.EventID - filter any + filter neorpc.SubscriptionFilter } ) @@ -38,7 +38,7 @@ func (f feed) EventID() neorpc.EventID { } // Filter implements neorpc.EventComparator interface and returns notification filter. -func (f feed) Filter() any { +func (f feed) Filter() neorpc.SubscriptionFilter { return f.filter } diff --git a/pkg/services/rpcsrv/subscription_test.go b/pkg/services/rpcsrv/subscription_test.go index c3d83eda2..857f3b2d1 100644 --- a/pkg/services/rpcsrv/subscription_test.go +++ b/pkg/services/rpcsrv/subscription_test.go @@ -676,3 +676,30 @@ func TestSubscriptionOverflow(t *testing.T) { finishedFlag.CompareAndSwap(false, true) c.Close() } + +func TestFilteredSubscriptions_InvalidFilter(t *testing.T) { + var cases = map[string]struct { + params string + }{ + "notification with long name": { + params: `["notification_from_execution", {"name":"notification_from_execution_with_long_name"}]`, + }, + "execution with invalid vm state": { + params: `["transaction_executed", {"state":"NOTHALT"}]`, + }, + } + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + for name, this := range cases { + t.Run(name, func(t *testing.T) { + resp := callWSGetRaw(t, c, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, this.params), respMsgs) + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + require.Contains(t, resp.Error.Error(), neorpc.ErrInvalidSubscriptionFilter.Error()) + }) + } + finishedFlag.CompareAndSwap(false, true) + c.Close() +}