mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-21 23:29:38 +00:00
Merge pull request #3258 from nspcc-dev/rpc_validation
rpc: subscription filters validity check
This commit is contained in:
commit
3176f72878
11 changed files with 142 additions and 50 deletions
|
@ -67,11 +67,11 @@ omitted if empty).
|
||||||
|
|
||||||
Recognized stream names:
|
Recognized stream names:
|
||||||
* `block_added`
|
* `block_added`
|
||||||
Filter: `primary` as an integer with primary (speaker) node index from
|
Filter: `primary` as an integer with a valid range of 0-255 with
|
||||||
ConsensusData and/or `since` field as an integer value with block
|
primary (speaker) node index from ConsensusData and/or `since` field as
|
||||||
index starting from which new block notifications will be received and/or
|
an integer value with block index starting from which new block
|
||||||
`till` field as an integer values containing block index till which new
|
notifications will be received and/or `till` field as an integer values
|
||||||
block notifications will be received.
|
containing block index till which new block notifications will be received.
|
||||||
* `header_of_added_block`
|
* `header_of_added_block`
|
||||||
Filter: `primary` as an integer with primary (speaker) node index from
|
Filter: `primary` as an integer with primary (speaker) node index from
|
||||||
ConsensusData and/or `since` field as an integer value with header
|
ConsensusData and/or `since` field as an integer value with header
|
||||||
|
@ -85,7 +85,8 @@ Recognized stream names:
|
||||||
* `notification_from_execution`
|
* `notification_from_execution`
|
||||||
Filter: `contract` field containing a string with hex-encoded Uint160 (LE
|
Filter: `contract` field containing a string with hex-encoded Uint160 (LE
|
||||||
representation) and/or `name` field containing a string with execution
|
representation) and/or `name` field containing a string with execution
|
||||||
notification name.
|
notification name which should be a valid UTF-8 string not longer than
|
||||||
|
32 bytes.
|
||||||
* `transaction_executed`
|
* `transaction_executed`
|
||||||
Filter: `state` field containing `HALT` or `FAULT` string for successful
|
Filter: `state` field containing `HALT` or `FAULT` string for successful
|
||||||
and failed executions respectively and/or `container` field containing
|
and failed executions respectively and/or `container` field containing
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
package neorpc
|
package neorpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/interop/runtime"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/vmstate"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -11,7 +16,7 @@ type (
|
||||||
// since/till the specified index inclusively). nil value treated as missing
|
// since/till the specified index inclusively). nil value treated as missing
|
||||||
// filter.
|
// filter.
|
||||||
BlockFilter struct {
|
BlockFilter struct {
|
||||||
Primary *int `json:"primary,omitempty"`
|
Primary *byte `json:"primary,omitempty"`
|
||||||
Since *uint32 `json:"since,omitempty"`
|
Since *uint32 `json:"since,omitempty"`
|
||||||
Till *uint32 `json:"till,omitempty"`
|
Till *uint32 `json:"till,omitempty"`
|
||||||
}
|
}
|
||||||
|
@ -49,6 +54,16 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SubscriptionFilter is an interface for all subscription filters.
|
||||||
|
type SubscriptionFilter interface {
|
||||||
|
// IsValid checks whether the filter is valid and returns
|
||||||
|
// a specific [ErrInvalidSubscriptionFilter] error if not.
|
||||||
|
IsValid() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrInvalidSubscriptionFilter is returned when the subscription filter is invalid.
|
||||||
|
var ErrInvalidSubscriptionFilter = errors.New("invalid subscription filter")
|
||||||
|
|
||||||
// Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly.
|
// Copy creates a deep copy of the BlockFilter. It handles nil BlockFilter correctly.
|
||||||
func (f *BlockFilter) Copy() *BlockFilter {
|
func (f *BlockFilter) Copy() *BlockFilter {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
|
@ -56,7 +71,7 @@ func (f *BlockFilter) Copy() *BlockFilter {
|
||||||
}
|
}
|
||||||
var res = new(BlockFilter)
|
var res = new(BlockFilter)
|
||||||
if f.Primary != nil {
|
if f.Primary != nil {
|
||||||
res.Primary = new(int)
|
res.Primary = new(byte)
|
||||||
*res.Primary = *f.Primary
|
*res.Primary = *f.Primary
|
||||||
}
|
}
|
||||||
if f.Since != nil {
|
if f.Since != nil {
|
||||||
|
@ -70,6 +85,11 @@ func (f *BlockFilter) Copy() *BlockFilter {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValid implements SubscriptionFilter interface.
|
||||||
|
func (f BlockFilter) IsValid() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the TxFilter. It handles nil TxFilter correctly.
|
// Copy creates a deep copy of the TxFilter. It handles nil TxFilter correctly.
|
||||||
func (f *TxFilter) Copy() *TxFilter {
|
func (f *TxFilter) Copy() *TxFilter {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
|
@ -87,6 +107,11 @@ func (f *TxFilter) Copy() *TxFilter {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValid implements SubscriptionFilter interface.
|
||||||
|
func (f TxFilter) IsValid() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the NotificationFilter. It handles nil NotificationFilter correctly.
|
// Copy creates a deep copy of the NotificationFilter. It handles nil NotificationFilter correctly.
|
||||||
func (f *NotificationFilter) Copy() *NotificationFilter {
|
func (f *NotificationFilter) Copy() *NotificationFilter {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
|
@ -104,6 +129,14 @@ func (f *NotificationFilter) Copy() *NotificationFilter {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValid implements SubscriptionFilter interface.
|
||||||
|
func (f NotificationFilter) IsValid() error {
|
||||||
|
if f.Name != nil && len(*f.Name) > runtime.MaxEventNameLen {
|
||||||
|
return fmt.Errorf("%w: NotificationFilter name parameter must be less than %d", ErrInvalidSubscriptionFilter, runtime.MaxEventNameLen)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the ExecutionFilter. It handles nil ExecutionFilter correctly.
|
// Copy creates a deep copy of the ExecutionFilter. It handles nil ExecutionFilter correctly.
|
||||||
func (f *ExecutionFilter) Copy() *ExecutionFilter {
|
func (f *ExecutionFilter) Copy() *ExecutionFilter {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
|
@ -121,6 +154,17 @@ func (f *ExecutionFilter) Copy() *ExecutionFilter {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValid implements SubscriptionFilter interface.
|
||||||
|
func (f ExecutionFilter) IsValid() error {
|
||||||
|
if f.State != nil {
|
||||||
|
if *f.State != vmstate.Halt.String() && *f.State != vmstate.Fault.String() {
|
||||||
|
return fmt.Errorf("%w: ExecutionFilter state parameter must be either %s or %s", ErrInvalidSubscriptionFilter, vmstate.Halt, vmstate.Fault)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the NotaryRequestFilter. It handles nil NotaryRequestFilter correctly.
|
// Copy creates a deep copy of the NotaryRequestFilter. It handles nil NotaryRequestFilter correctly.
|
||||||
func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
|
func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
|
@ -141,3 +185,8 @@ func (f *NotaryRequestFilter) Copy() *NotaryRequestFilter {
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValid implements SubscriptionFilter interface.
|
||||||
|
func (f NotaryRequestFilter) IsValid() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -16,12 +16,12 @@ func TestBlockFilterCopy(t *testing.T) {
|
||||||
tf = bf.Copy()
|
tf = bf.Copy()
|
||||||
require.Equal(t, bf, tf)
|
require.Equal(t, bf, tf)
|
||||||
|
|
||||||
bf.Primary = new(int)
|
bf.Primary = new(byte)
|
||||||
*bf.Primary = 42
|
*bf.Primary = 42
|
||||||
|
|
||||||
tf = bf.Copy()
|
tf = bf.Copy()
|
||||||
require.Equal(t, bf, tf)
|
require.Equal(t, bf, tf)
|
||||||
*bf.Primary = 100500
|
*bf.Primary = 100
|
||||||
require.NotEqual(t, bf, tf)
|
require.NotEqual(t, bf, tf)
|
||||||
|
|
||||||
bf.Since = new(uint32)
|
bf.Since = new(uint32)
|
||||||
|
|
|
@ -13,7 +13,7 @@ type (
|
||||||
// filter notifications.
|
// filter notifications.
|
||||||
Comparator interface {
|
Comparator interface {
|
||||||
EventID() neorpc.EventID
|
EventID() neorpc.EventID
|
||||||
Filter() any
|
Filter() neorpc.SubscriptionFilter
|
||||||
}
|
}
|
||||||
// Container is an interface required from notification event to be able to
|
// Container is an interface required from notification event to be able to
|
||||||
// pass filter.
|
// pass filter.
|
||||||
|
@ -42,7 +42,7 @@ func Matches(f Comparator, r Container) bool {
|
||||||
} else {
|
} else {
|
||||||
b = &r.EventPayload().(*block.Block).Header
|
b = &r.EventPayload().(*block.Block).Header
|
||||||
}
|
}
|
||||||
primaryOk := filt.Primary == nil || *filt.Primary == int(b.PrimaryIndex)
|
primaryOk := filt.Primary == nil || *filt.Primary == b.PrimaryIndex
|
||||||
sinceOk := filt.Since == nil || *filt.Since <= b.Index
|
sinceOk := filt.Since == nil || *filt.Since <= b.Index
|
||||||
tillOk := filt.Till == nil || b.Index <= *filt.Till
|
tillOk := filt.Till == nil || b.Index <= *filt.Till
|
||||||
return primaryOk && sinceOk && tillOk
|
return primaryOk && sinceOk && tillOk
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
type (
|
type (
|
||||||
testComparator struct {
|
testComparator struct {
|
||||||
id neorpc.EventID
|
id neorpc.EventID
|
||||||
filter any
|
filter neorpc.SubscriptionFilter
|
||||||
}
|
}
|
||||||
testContainer struct {
|
testContainer struct {
|
||||||
id neorpc.EventID
|
id neorpc.EventID
|
||||||
|
@ -29,7 +29,7 @@ type (
|
||||||
func (c testComparator) EventID() neorpc.EventID {
|
func (c testComparator) EventID() neorpc.EventID {
|
||||||
return c.id
|
return c.id
|
||||||
}
|
}
|
||||||
func (c testComparator) Filter() any {
|
func (c testComparator) Filter() neorpc.SubscriptionFilter {
|
||||||
return c.filter
|
return c.filter
|
||||||
}
|
}
|
||||||
func (c testContainer) EventID() neorpc.EventID {
|
func (c testContainer) EventID() neorpc.EventID {
|
||||||
|
@ -40,8 +40,8 @@ func (c testContainer) EventPayload() any {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMatches(t *testing.T) {
|
func TestMatches(t *testing.T) {
|
||||||
primary := 1
|
primary := byte(1)
|
||||||
badPrimary := 2
|
badPrimary := byte(2)
|
||||||
index := uint32(5)
|
index := uint32(5)
|
||||||
badHigherIndex := uint32(6)
|
badHigherIndex := uint32(6)
|
||||||
badLowerIndex := index - 1
|
badLowerIndex := index - 1
|
||||||
|
|
|
@ -127,7 +127,7 @@ func (r *blockReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *blockReceiver) Filter() any {
|
func (r *blockReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -174,7 +174,7 @@ func (r *headerOfAddedBlockReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *headerOfAddedBlockReceiver) Filter() any {
|
func (r *headerOfAddedBlockReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -220,7 +220,7 @@ func (r *txReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *txReceiver) Filter() any {
|
func (r *txReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,7 @@ func (r *executionNotificationReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *executionNotificationReceiver) Filter() any {
|
func (r *executionNotificationReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -314,7 +314,7 @@ func (r *executionReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *executionReceiver) Filter() any {
|
func (r *executionReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -361,7 +361,7 @@ func (r *notaryRequestReceiver) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.Comparator interface.
|
// Filter implements neorpc.Comparator interface.
|
||||||
func (r *notaryRequestReceiver) Filter() any {
|
func (r *notaryRequestReceiver) Filter() neorpc.SubscriptionFilter {
|
||||||
if r.filter == nil {
|
if r.filter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -766,6 +766,11 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
|
||||||
func (c *WSClient) performSubscription(params []any, rcvr notificationReceiver) (string, error) {
|
func (c *WSClient) performSubscription(params []any, rcvr notificationReceiver) (string, error) {
|
||||||
var resp string
|
var resp string
|
||||||
|
|
||||||
|
if flt := rcvr.Filter(); flt != nil {
|
||||||
|
if err := flt.IsValid(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := c.performRequest("subscribe", params, &resp); err != nil {
|
if err := c.performRequest("subscribe", params, &resp); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -872,11 +877,6 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s
|
||||||
}
|
}
|
||||||
params := []any{"transaction_executed"}
|
params := []any{"transaction_executed"}
|
||||||
if flt != nil {
|
if flt != nil {
|
||||||
if flt.State != nil {
|
|
||||||
if *flt.State != "HALT" && *flt.State != "FAULT" {
|
|
||||||
return "", errors.New("bad state parameter")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
flt = flt.Copy()
|
flt = flt.Copy()
|
||||||
params = append(params, *flt)
|
params = append(params, *flt)
|
||||||
}
|
}
|
||||||
|
|
|
@ -369,7 +369,20 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
filter := "NONE"
|
filter := "NONE"
|
||||||
_, err = wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &filter}, make(chan *state.AppExecResult))
|
_, err = wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &filter}, make(chan *state.AppExecResult))
|
||||||
require.Error(t, err)
|
require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter)
|
||||||
|
wsc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSExecutionNotificationNameCheck(t *testing.T) {
|
||||||
|
// Will answer successfully if request slips through.
|
||||||
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
||||||
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
|
require.NoError(t, wsc.Init())
|
||||||
|
filter := "notification_from_execution_with_long_name"
|
||||||
|
_, err = wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &filter}, make(chan *state.ContainedNotificationEvent))
|
||||||
|
require.ErrorIs(t, err, neorpc.ErrInvalidSubscriptionFilter)
|
||||||
wsc.Close()
|
wsc.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,7 +394,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"block header primary",
|
{"block header primary",
|
||||||
func(t *testing.T, wsc *WSClient) {
|
func(t *testing.T, wsc *WSClient) {
|
||||||
primary := 3
|
primary := byte(3)
|
||||||
_, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Header))
|
_, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Header))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
|
@ -389,7 +402,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, 3, *filt.Primary)
|
require.Equal(t, byte(3), *filt.Primary)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Since)
|
require.Equal(t, (*uint32)(nil), filt.Since)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Till)
|
require.Equal(t, (*uint32)(nil), filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -404,7 +417,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, (*int)(nil), filt.Primary)
|
require.Equal(t, (*byte)(nil), filt.Primary)
|
||||||
require.Equal(t, uint32(3), *filt.Since)
|
require.Equal(t, uint32(3), *filt.Since)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Till)
|
require.Equal(t, (*uint32)(nil), filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -419,7 +432,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, (*int)(nil), filt.Primary)
|
require.Equal(t, (*byte)(nil), filt.Primary)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Since)
|
require.Equal(t, (*uint32)(nil), filt.Since)
|
||||||
require.Equal(t, (uint32)(3), *filt.Till)
|
require.Equal(t, (uint32)(3), *filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -428,7 +441,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
func(t *testing.T, wsc *WSClient) {
|
func(t *testing.T, wsc *WSClient) {
|
||||||
var (
|
var (
|
||||||
since uint32 = 3
|
since uint32 = 3
|
||||||
primary = 2
|
primary = byte(2)
|
||||||
till uint32 = 5
|
till uint32 = 5
|
||||||
)
|
)
|
||||||
_, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{
|
_, err := wsc.ReceiveHeadersOfAddedBlocks(&neorpc.BlockFilter{
|
||||||
|
@ -442,14 +455,14 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, 2, *filt.Primary)
|
require.Equal(t, byte(2), *filt.Primary)
|
||||||
require.Equal(t, uint32(3), *filt.Since)
|
require.Equal(t, uint32(3), *filt.Since)
|
||||||
require.Equal(t, uint32(5), *filt.Till)
|
require.Equal(t, uint32(5), *filt.Till)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{"blocks primary",
|
{"blocks primary",
|
||||||
func(t *testing.T, wsc *WSClient) {
|
func(t *testing.T, wsc *WSClient) {
|
||||||
primary := 3
|
primary := byte(3)
|
||||||
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block))
|
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
|
@ -457,7 +470,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, 3, *filt.Primary)
|
require.Equal(t, byte(3), *filt.Primary)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Since)
|
require.Equal(t, (*uint32)(nil), filt.Since)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Till)
|
require.Equal(t, (*uint32)(nil), filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -472,7 +485,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, (*int)(nil), filt.Primary)
|
require.Equal(t, (*byte)(nil), filt.Primary)
|
||||||
require.Equal(t, uint32(3), *filt.Since)
|
require.Equal(t, uint32(3), *filt.Since)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Till)
|
require.Equal(t, (*uint32)(nil), filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -487,7 +500,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, (*int)(nil), filt.Primary)
|
require.Equal(t, (*byte)(nil), filt.Primary)
|
||||||
require.Equal(t, (*uint32)(nil), filt.Since)
|
require.Equal(t, (*uint32)(nil), filt.Since)
|
||||||
require.Equal(t, (uint32)(3), *filt.Till)
|
require.Equal(t, (uint32)(3), *filt.Till)
|
||||||
},
|
},
|
||||||
|
@ -496,7 +509,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
func(t *testing.T, wsc *WSClient) {
|
func(t *testing.T, wsc *WSClient) {
|
||||||
var (
|
var (
|
||||||
since uint32 = 3
|
since uint32 = 3
|
||||||
primary = 2
|
primary = byte(2)
|
||||||
till uint32 = 5
|
till uint32 = 5
|
||||||
)
|
)
|
||||||
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{
|
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{
|
||||||
|
@ -510,7 +523,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
param := p.Value(1)
|
param := p.Value(1)
|
||||||
filt := new(neorpc.BlockFilter)
|
filt := new(neorpc.BlockFilter)
|
||||||
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
require.NoError(t, json.Unmarshal(param.RawMessage, filt))
|
||||||
require.Equal(t, 2, *filt.Primary)
|
require.Equal(t, byte(2), *filt.Primary)
|
||||||
require.Equal(t, uint32(3), *filt.Since)
|
require.Equal(t, uint32(3), *filt.Since)
|
||||||
require.Equal(t, uint32(5), *filt.Till)
|
require.Equal(t, uint32(5), *filt.Till)
|
||||||
},
|
},
|
||||||
|
|
|
@ -2113,9 +2113,9 @@ func TestWSClient_SubscriptionsCompat(t *testing.T) {
|
||||||
blocks := getTestBlocks(t)
|
blocks := getTestBlocks(t)
|
||||||
bCount := uint32(0)
|
bCount := uint32(0)
|
||||||
|
|
||||||
getData := func(t *testing.T) (*block.Block, int, util.Uint160, string, string) {
|
getData := func(t *testing.T) (*block.Block, byte, util.Uint160, string, string) {
|
||||||
b1 := blocks[bCount]
|
b1 := blocks[bCount]
|
||||||
primary := int(b1.PrimaryIndex)
|
primary := b1.PrimaryIndex
|
||||||
tx := b1.Transactions[0]
|
tx := b1.Transactions[0]
|
||||||
sender := tx.Sender()
|
sender := tx.Sender()
|
||||||
ntfName := "Transfer"
|
ntfName := "Transfer"
|
||||||
|
|
|
@ -2729,7 +2729,7 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor
|
||||||
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, "P2PSigExtensions are disabled")
|
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, "P2PSigExtensions are disabled")
|
||||||
}
|
}
|
||||||
// Optional filter.
|
// Optional filter.
|
||||||
var filter any
|
var filter neorpc.SubscriptionFilter
|
||||||
if p := reqParams.Value(1); p != nil {
|
if p := reqParams.Value(1); p != nil {
|
||||||
param := *p
|
param := *p
|
||||||
jd := json.NewDecoder(bytes.NewReader(param.RawMessage))
|
jd := json.NewDecoder(bytes.NewReader(param.RawMessage))
|
||||||
|
@ -2754,16 +2754,18 @@ func (s *Server) subscribe(reqParams params.Params, sub *subscriber) (any, *neor
|
||||||
case neorpc.ExecutionEventID:
|
case neorpc.ExecutionEventID:
|
||||||
flt := new(neorpc.ExecutionFilter)
|
flt := new(neorpc.ExecutionFilter)
|
||||||
err = jd.Decode(flt)
|
err = jd.Decode(flt)
|
||||||
if err == nil && (flt.State == nil || (*flt.State == "HALT" || *flt.State == "FAULT")) {
|
filter = *flt
|
||||||
filter = *flt
|
|
||||||
} else if err == nil {
|
|
||||||
err = errors.New("invalid state")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
|
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if filter != nil {
|
||||||
|
err = filter.IsValid()
|
||||||
|
if err != nil {
|
||||||
|
return nil, neorpc.WrapErrorWithData(neorpc.ErrInvalidParams, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.subsLock.Lock()
|
s.subsLock.Lock()
|
||||||
var id int
|
var id int
|
||||||
|
|
|
@ -28,7 +28,7 @@ type (
|
||||||
// feed stores subscriber's desired event ID with filter.
|
// feed stores subscriber's desired event ID with filter.
|
||||||
feed struct {
|
feed struct {
|
||||||
event neorpc.EventID
|
event neorpc.EventID
|
||||||
filter any
|
filter neorpc.SubscriptionFilter
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ func (f feed) EventID() neorpc.EventID {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter implements neorpc.EventComparator interface and returns notification filter.
|
// Filter implements neorpc.EventComparator interface and returns notification filter.
|
||||||
func (f feed) Filter() any {
|
func (f feed) Filter() neorpc.SubscriptionFilter {
|
||||||
return f.filter
|
return f.filter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -676,3 +676,30 @@ func TestSubscriptionOverflow(t *testing.T) {
|
||||||
finishedFlag.CompareAndSwap(false, true)
|
finishedFlag.CompareAndSwap(false, true)
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilteredSubscriptions_InvalidFilter(t *testing.T) {
|
||||||
|
var cases = map[string]struct {
|
||||||
|
params string
|
||||||
|
}{
|
||||||
|
"notification with long name": {
|
||||||
|
params: `["notification_from_execution", {"name":"notification_from_execution_with_long_name"}]`,
|
||||||
|
},
|
||||||
|
"execution with invalid vm state": {
|
||||||
|
params: `["transaction_executed", {"state":"NOTHALT"}]`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
|
||||||
|
defer chain.Close()
|
||||||
|
defer rpcSrv.Shutdown()
|
||||||
|
|
||||||
|
for name, this := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
resp := callWSGetRaw(t, c, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, this.params), respMsgs)
|
||||||
|
require.NotNil(t, resp.Error)
|
||||||
|
require.Nil(t, resp.Result)
|
||||||
|
require.Contains(t, resp.Error.Error(), neorpc.ErrInvalidSubscriptionFilter.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
finishedFlag.CompareAndSwap(false, true)
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue