From 2c0844221a38fb78783e2421c9f3ff98e4eb00ee Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 7 Dec 2022 15:07:42 +0300 Subject: [PATCH] rpcclient: fix filtered naive subscriptions receiver We should return the filter itself instead of pointer. --- pkg/rpcclient/wsclient.go | 38 +++--- pkg/services/rpcsrv/client_test.go | 204 +++++++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 20 deletions(-) diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 48f3842fd..fec9e3d2d 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -640,13 +640,13 @@ func (c *WSClient) performSubscription(params []interface{}, rcvr notificationRe // // Deprecated: please, use ReceiveBlocks. This method will be removed in future versions. func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) { - var flt *neorpc.BlockFilter + var flt interface{} if primary != nil { - flt = &neorpc.BlockFilter{Primary: primary} + flt = neorpc.BlockFilter{Primary: primary} } params := []interface{}{"block_added"} if flt != nil { - params = append(params, *flt) + params = append(params, flt) } r := &naiveReceiver{ eventID: neorpc.BlockEventID, @@ -685,13 +685,13 @@ func (c *WSClient) ReceiveBlocks(flt *neorpc.BlockFilter, rcvr chan<- *block.Blo // // Deprecated: please, use ReceiveTransactions. This method will be removed in future versions. func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) { - var flt *neorpc.TxFilter + var flt interface{} if sender != nil || signer != nil { - flt = &neorpc.TxFilter{Sender: sender, Signer: signer} + flt = neorpc.TxFilter{Sender: sender, Signer: signer} } params := []interface{}{"transaction_added"} if flt != nil { - params = append(params, *flt) + params = append(params, flt) } r := &naiveReceiver{ eventID: neorpc.TransactionEventID, @@ -731,13 +731,13 @@ func (c *WSClient) ReceiveTransactions(flt *neorpc.TxFilter, rcvr chan<- *transa // // Deprecated: please, use ReceiveExecutionNotifications. This method will be removed in future versions. func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) { - var flt *neorpc.NotificationFilter + var flt interface{} if contract != nil || name != nil { - flt = &neorpc.NotificationFilter{Contract: contract, Name: name} + flt = neorpc.NotificationFilter{Contract: contract, Name: name} } params := []interface{}{"notification_from_execution"} if flt != nil { - params = append(params, *flt) + params = append(params, flt) } r := &naiveReceiver{ eventID: neorpc.NotificationEventID, @@ -777,18 +777,16 @@ func (c *WSClient) ReceiveExecutionNotifications(flt *neorpc.NotificationFilter, // // Deprecated: please, use ReceiveExecutions. This method will be removed in future versions. func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) { - var flt *neorpc.ExecutionFilter + var flt interface{} if state != nil { - flt = &neorpc.ExecutionFilter{State: state} + if *state != "HALT" && *state != "FAULT" { + return "", errors.New("bad state parameter") + } + flt = neorpc.ExecutionFilter{State: state} } params := []interface{}{"transaction_executed"} if flt != nil { - if flt.State != nil { - if *flt.State != "HALT" && *flt.State != "FAULT" { - return "", errors.New("bad state parameter") - } - } - params = append(params, *flt) + params = append(params, flt) } r := &naiveReceiver{ eventID: neorpc.ExecutionEventID, @@ -834,13 +832,13 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s // // Deprecated: please, use ReceiveNotaryRequests. This method will be removed in future versions. func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) { - var flt *neorpc.TxFilter + var flt interface{} if sender != nil || mainSigner != nil { - flt = &neorpc.TxFilter{Sender: sender, Signer: mainSigner} + flt = neorpc.TxFilter{Sender: sender, Signer: mainSigner} } params := []interface{}{"notary_request_event"} if flt != nil { - params = append(params, *flt) + params = append(params, flt) } r := &naiveReceiver{ eventID: neorpc.NotaryRequestEventID, diff --git a/pkg/services/rpcsrv/client_test.go b/pkg/services/rpcsrv/client_test.go index e80eacc2c..d6af19ea4 100644 --- a/pkg/services/rpcsrv/client_test.go +++ b/pkg/services/rpcsrv/client_test.go @@ -57,6 +57,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/wallet" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -2260,3 +2261,206 @@ waitloop: } } } + +// TestWSClient_SubscriptionsCompat is aimed to test both deprecated and relevant +// subscriptions API with filtered and non-filtered subscriptions from the WSClient +// user side. +func TestWSClient_SubscriptionsCompat(t *testing.T) { + chain, rpcSrv, httpSrv := initClearServerWithServices(t, false, false, true) + defer chain.Close() + defer rpcSrv.Shutdown() + + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + require.NoError(t, err) + require.NoError(t, c.Init()) + blocks := getTestBlocks(t) + bCount := uint32(0) + + getData := func(t *testing.T) (*block.Block, int, util.Uint160, string, string) { + b1 := blocks[bCount] + primary := int(b1.PrimaryIndex) + tx := b1.Transactions[0] + sender := tx.Sender() + ntfName := "Transfer" + st := vmstate.Halt.String() + bCount++ + return b1, primary, sender, ntfName, st + } + checkDeprecated := func(t *testing.T, filtered bool) { + b, primary, sender, ntfName, st := getData(t) + var bID, txID, ntfID, aerID string + if filtered { + bID, err = c.SubscribeForNewBlocks(&primary) //nolint:staticcheck // SA1019: c.SubscribeForNewBlocks is deprecated + require.NoError(t, err) + txID, err = c.SubscribeForNewTransactions(&sender, nil) //nolint:staticcheck // SA1019: c.SubscribeForNewTransactions is deprecated + require.NoError(t, err) + ntfID, err = c.SubscribeForExecutionNotifications(nil, &ntfName) //nolint:staticcheck // SA1019: c.SubscribeForExecutionNotifications is deprecated + require.NoError(t, err) + aerID, err = c.SubscribeForTransactionExecutions(&st) //nolint:staticcheck // SA1019: c.SubscribeForTransactionExecutions is deprecated + require.NoError(t, err) + } else { + bID, err = c.SubscribeForNewBlocks(nil) //nolint:staticcheck // SA1019: c.SubscribeForNewBlocks is deprecated + require.NoError(t, err) + txID, err = c.SubscribeForNewTransactions(nil, nil) //nolint:staticcheck // SA1019: c.SubscribeForNewTransactions is deprecated + require.NoError(t, err) + ntfID, err = c.SubscribeForExecutionNotifications(nil, nil) //nolint:staticcheck // SA1019: c.SubscribeForExecutionNotifications is deprecated + require.NoError(t, err) + aerID, err = c.SubscribeForTransactionExecutions(nil) //nolint:staticcheck // SA1019: c.SubscribeForTransactionExecutions is deprecated + require.NoError(t, err) + } + + var ( + lock sync.RWMutex + received byte + exitCh = make(chan struct{}) + ) + go func() { + dispatcher: + for { + select { + case ntf := <-c.Notifications: //nolint:staticcheck // SA1019: c.Notifications is deprecated + lock.Lock() + switch ntf.Type { + case neorpc.BlockEventID: + received |= 1 + case neorpc.TransactionEventID: + received |= 1 << 1 + case neorpc.NotificationEventID: + received |= 1 << 2 + case neorpc.ExecutionEventID: + received |= 1 << 3 + } + lock.Unlock() + case <-exitCh: + break dispatcher + } + } + drainLoop: + for { + select { + case <-c.Notifications: //nolint:staticcheck // SA1019: c.Notifications is deprecated + default: + break drainLoop + } + } + }() + + // Accept the next block and wait for events. + require.NoError(t, chain.AddBlock(b)) + assert.Eventually(t, func() bool { + lock.RLock() + defer lock.RUnlock() + + return received == 1<<4-1 + }, time.Second, 100*time.Millisecond) + + require.NoError(t, c.Unsubscribe(bID)) + require.NoError(t, c.Unsubscribe(txID)) + require.NoError(t, c.Unsubscribe(ntfID)) + require.NoError(t, c.Unsubscribe(aerID)) + exitCh <- struct{}{} + } + t.Run("deprecated, filtered", func(t *testing.T) { + checkDeprecated(t, true) + }) + t.Run("deprecated, non-filtered", func(t *testing.T) { + checkDeprecated(t, false) + }) + + checkRelevant := func(t *testing.T, filtered bool) { + b, primary, sender, ntfName, st := getData(t) + var ( + bID, txID, ntfID, aerID string + blockCh = make(chan *block.Block) + txCh = make(chan *transaction.Transaction) + ntfCh = make(chan *state.ContainedNotificationEvent) + aerCh = make(chan *state.AppExecResult) + bFlt *neorpc.BlockFilter + txFlt *neorpc.TxFilter + ntfFlt *neorpc.NotificationFilter + aerFlt *neorpc.ExecutionFilter + ) + if filtered { + bFlt = &neorpc.BlockFilter{Primary: &primary} + txFlt = &neorpc.TxFilter{Sender: &sender} + ntfFlt = &neorpc.NotificationFilter{Name: &ntfName} + aerFlt = &neorpc.ExecutionFilter{State: &st} + } + bID, err = c.ReceiveBlocks(bFlt, blockCh) + require.NoError(t, err) + txID, err = c.ReceiveTransactions(txFlt, txCh) + require.NoError(t, err) + ntfID, err = c.ReceiveExecutionNotifications(ntfFlt, ntfCh) + require.NoError(t, err) + aerID, err = c.ReceiveExecutions(aerFlt, aerCh) + require.NoError(t, err) + + var ( + lock sync.RWMutex + received byte + exitCh = make(chan struct{}) + ) + go func() { + dispatcher: + for { + select { + case <-blockCh: + lock.Lock() + received |= 1 + lock.Unlock() + case <-txCh: + lock.Lock() + received |= 1 << 1 + lock.Unlock() + case <-ntfCh: + lock.Lock() + received |= 1 << 2 + lock.Unlock() + case <-aerCh: + lock.Lock() + received |= 1 << 3 + lock.Unlock() + case <-exitCh: + break dispatcher + } + } + drainLoop: + for { + select { + case <-blockCh: + case <-txCh: + case <-ntfCh: + case <-aerCh: + default: + break drainLoop + } + } + close(blockCh) + close(txCh) + close(ntfCh) + close(aerCh) + }() + + // Accept the next block and wait for events. + require.NoError(t, chain.AddBlock(b)) + assert.Eventually(t, func() bool { + lock.RLock() + defer lock.RUnlock() + + return received == 1<<4-1 + }, time.Second, 100*time.Millisecond) + + require.NoError(t, c.Unsubscribe(bID)) + require.NoError(t, c.Unsubscribe(txID)) + require.NoError(t, c.Unsubscribe(ntfID)) + require.NoError(t, c.Unsubscribe(aerID)) + exitCh <- struct{}{} + } + t.Run("relevant, filtered", func(t *testing.T) { + checkRelevant(t, true) + }) + t.Run("relevant, non-filtered", func(t *testing.T) { + checkRelevant(t, false) + }) +}