diff --git a/docs/notifications.md b/docs/notifications.md index cd47e841b..68889780b 100644 --- a/docs/notifications.md +++ b/docs/notifications.md @@ -85,7 +85,8 @@ Recognized stream names: * `notification_from_execution` Filter: `contract` field containing a string with hex-encoded Uint160 (LE representation) and/or `name` field containing a string with execution - notification name. + notification name which should be a valid UTF-8 string not longer than + 32 bytes. * `transaction_executed` Filter: `state` field containing `HALT` or `FAULT` string for successful and failed executions respectively and/or `container` field containing diff --git a/pkg/neorpc/filters.go b/pkg/neorpc/filters.go index a85bfd88c..f05cf9343 100644 --- a/pkg/neorpc/filters.go +++ b/pkg/neorpc/filters.go @@ -1,8 +1,13 @@ package neorpc import ( + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/core/interop/runtime" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" ) type ( @@ -49,6 +54,16 @@ type ( } ) +// SubscriptionFilter is an interface for all subscription filters. +type SubscriptionFilter interface { + // IsValid checks whether the filter is valid and returns + // a specific [ErrInvalidSubscriptionFilter] error if not. + IsValid() error +} + +// ErrInvalidSubscriptionFilter is returned when the subscription filter is invalid. +var ErrInvalidSubscriptionFilter = errors.New("invalid subscription filter") + // Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly. func (f *BlockFilter) Copy() *BlockFilter { if f == nil { @@ -70,6 +85,11 @@ func (f *BlockFilter) Copy() *BlockFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f BlockFilter) IsValid() error { + return nil +} + // Copy creates a deep copy of the TxFilter. It handles nil TxFilter correctly. func (f *TxFilter) Copy() *TxFilter { if f == nil { @@ -87,6 +107,11 @@ func (f *TxFilter) Copy() *TxFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f TxFilter) IsValid() error { + return nil +} + // Copy creates a deep copy of the NotificationFilter. It handles nil NotificationFilter correctly. func (f *NotificationFilter) Copy() *NotificationFilter { if f == nil { @@ -104,6 +129,14 @@ func (f *NotificationFilter) Copy() *NotificationFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f NotificationFilter) IsValid() error { + if f.Name != nil && len(*f.Name) > runtime.MaxEventNameLen { + return fmt.Errorf("%w: NotificationFilter name parameter must be less than %d", ErrInvalidSubscriptionFilter, runtime.MaxEventNameLen) + } + return nil +} + // Copy creates a deep copy of the ExecutionFilter. It handles nil ExecutionFilter correctly. func (f *ExecutionFilter) Copy() *ExecutionFilter { if f == nil { @@ -121,6 +154,17 @@ func (f *ExecutionFilter) Copy() *ExecutionFilter { return res } +// IsValid implements SubscriptionFilter interface. +func (f ExecutionFilter) IsValid() error { + if f.State != nil { + if *f.State != vmstate.Halt.String() && *f.State != vmstate.Fault.String() { + return fmt.Errorf("%w: ExecutionFilter state parameter must be either %s or %s", ErrInvalidSubscriptionFilter, vmstate.Halt, vmstate.Fault) + } + } + + return nil +} + // Copy creates a deep copy of the NotaryRequestFilter. It handles nil NotaryRequestFilter correctly. func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter { if f == nil { @@ -141,3 +185,8 @@ func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter { } return res } + +// IsValid implements SubscriptionFilter interface. +func (f NotaryRequestFilter) IsValid() error { + return nil +} diff --git a/pkg/neorpc/rpcevent/filter.go b/pkg/neorpc/rpcevent/filter.go index 90696a2f0..9b0bd0c69 100644 --- a/pkg/neorpc/rpcevent/filter.go +++ b/pkg/neorpc/rpcevent/filter.go @@ -13,7 +13,7 @@ type ( // filter notifications. Comparator interface { EventID() neorpc.EventID - Filter() any + Filter() neorpc.SubscriptionFilter } // Container is an interface required from notification event to be able to // pass filter. diff --git a/pkg/neorpc/rpcevent/filter_test.go b/pkg/neorpc/rpcevent/filter_test.go index 956ab2686..70cfa9225 100644 --- a/pkg/neorpc/rpcevent/filter_test.go +++ b/pkg/neorpc/rpcevent/filter_test.go @@ -18,7 +18,7 @@ import ( type ( testComparator struct { id neorpc.EventID - filter any + filter neorpc.SubscriptionFilter } testContainer struct { id neorpc.EventID @@ -29,7 +29,7 @@ type ( func (c testComparator) EventID() neorpc.EventID { return c.id } -func (c testComparator) Filter() any { +func (c testComparator) Filter() neorpc.SubscriptionFilter { return c.filter } func (c testContainer) EventID() neorpc.EventID { diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 962b26455..9674766da 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -127,7 +127,7 @@ func (r *blockReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *blockReceiver) Filter() any { +func (r *blockReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -174,7 +174,7 @@ func (r *headerOfAddedBlockReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *headerOfAddedBlockReceiver) Filter() any { +func (r *headerOfAddedBlockReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -220,7 +220,7 @@ func (r *txReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *txReceiver) Filter() any { +func (r *txReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -267,7 +267,7 @@ func (r *executionNotificationReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *executionNotificationReceiver) Filter() any { +func (r *executionNotificationReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -314,7 +314,7 @@ func (r *executionReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *executionReceiver) Filter() any { +func (r *executionReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -361,7 +361,7 @@ func (r *notaryRequestReceiver) EventID() neorpc.EventID { } // Filter implements neorpc.Comparator interface. -func (r *notaryRequestReceiver) Filter() any { +func (r *notaryRequestReceiver) Filter() neorpc.SubscriptionFilter { if r.filter == nil { return nil } @@ -766,6 +766,11 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { func (c *WSClient) performSubscription(params []any, rcvr notificationReceiver) (string, error) { var resp string + if flt := rcvr.Filter(); flt != nil { + if err := flt.IsValid(); err != nil { + return "", err + } + } if err := c.performRequest("subscribe", params, &resp); err != nil { return "", err } @@ -872,11 +877,6 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s } params := []any{"transaction_executed"} if flt != nil { - if flt.State != nil { - if *flt.State != "HALT" && *flt.State != "FAULT" { - return "", errors.New("bad state parameter") - } - } flt = flt.Copy() params = append(params, *flt) } diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index a983e0609..693783e79 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -369,7 +369,20 @@ func TestWSExecutionVMStateCheck(t *testing.T) { require.NoError(t, wsc.Init()) filter := "NONE" _, err = wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &filter}, make(chan *state.AppExecResult)) - require.Error(t, err) + require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter) + wsc.Close() +} + +func TestWSExecutionNotificationNameCheck(t *testing.T) { + // Will answer successfully if request slips through. + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) + require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + require.NoError(t, wsc.Init()) + filter := "notification_from_execution_with_long_name" + _, err = wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &filter}, make(chan *state.ContainedNotificationEvent)) + require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter) wsc.Close() } diff --git a/pkg/services/rpcsrv/server.go b/pkg/services/rpcsrv/server.go index bb30f566d..5c4e2e9f6 100644 --- a/pkg/services/rpcsrv/server.go +++ b/pkg/services/rpcsrv/server.go @@ -2729,7 +2729,7 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, "P2PSigExtensions are disabled") } // Optional filter. - var filter any + var filter neorpc.SubscriptionFilter if p := reqParams.Value(1); p != nil { param := *p jd := json.NewDecoder(bytes.NewReader(param.RawMessage)) @@ -2754,16 +2754,18 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor case neorpc.ExecutionEventID: flt := new(neorpc.ExecutionFilter) err = jd.Decode(flt) - if err == nil && (flt.State == nil || (*flt.State == "HALT" || *flt.State == "FAULT")) { - filter = *flt - } else if err == nil { - err = errors.New("invalid state") - } + filter = *flt } if err != nil { return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error()) } } + if filter != nil { + err = filter.IsValid() + if err != nil { + return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error()) + } + } s.subsLock.Lock() var id int diff --git a/pkg/services/rpcsrv/subscription.go b/pkg/services/rpcsrv/subscription.go index ad271cae4..9e3b05893 100644 --- a/pkg/services/rpcsrv/subscription.go +++ b/pkg/services/rpcsrv/subscription.go @@ -28,7 +28,7 @@ type ( // feed stores subscriber's desired event ID with filter. feed struct { event neorpc.EventID - filter any + filter neorpc.SubscriptionFilter } ) @@ -38,7 +38,7 @@ func (f feed) EventID() neorpc.EventID { } // Filter implements neorpc.EventComparator interface and returns notification filter. -func (f feed) Filter() any { +func (f feed) Filter() neorpc.SubscriptionFilter { return f.filter } diff --git a/pkg/services/rpcsrv/subscription_test.go b/pkg/services/rpcsrv/subscription_test.go index c3d83eda2..857f3b2d1 100644 --- a/pkg/services/rpcsrv/subscription_test.go +++ b/pkg/services/rpcsrv/subscription_test.go @@ -676,3 +676,30 @@ func TestSubscriptionOverflow(t *testing.T) { finishedFlag.CompareAndSwap(false, true) c.Close() } + +func TestFilteredSubscriptions_InvalidFilter(t *testing.T) { + var cases = map[string]struct { + params string + }{ + "notification with long name": { + params: `["notification_from_execution", {"name":"notification_from_execution_with_long_name"}]`, + }, + "execution with invalid vm state": { + params: `["transaction_executed", {"state":"NOTHALT"}]`, + }, + } + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + for name, this := range cases { + t.Run(name, func(t *testing.T) { + resp := callWSGetRaw(t, c, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, this.params), respMsgs) + require.NotNil(t, resp.Error) + require.Nil(t, resp.Result) + require.Contains(t, resp.Error.Error(), neorpc.ErrInvalidSubscriptionFilter.Error()) + }) + } + finishedFlag.CompareAndSwap(false, true) + c.Close() +}