Merge pull request #2836 from nspcc-dev/fix-subs

rpcclient: fix filtered naive subscriptions receiver
This commit is contained in:
Roman Khimov 2022-12-07 21:17:07 +07:00 committed by GitHub
commit b1f1405f42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 222 additions and 20 deletions

View file

@ -640,13 +640,13 @@ func (c *WSClient) performSubscription(params []interface{}, rcvr notificationRe
// //
// Deprecated: please, use ReceiveBlocks. This method will be removed in future versions. // Deprecated: please, use ReceiveBlocks. This method will be removed in future versions.
func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) { func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) {
var flt *neorpc.BlockFilter var flt interface{}
if primary != nil { if primary != nil {
flt = &neorpc.BlockFilter{Primary: primary} flt = neorpc.BlockFilter{Primary: primary}
} }
params := []interface{}{"block_added"} params := []interface{}{"block_added"}
if flt != nil { if flt != nil {
params = append(params, *flt) params = append(params, flt)
} }
r := &naiveReceiver{ r := &naiveReceiver{
eventID: neorpc.BlockEventID, eventID: neorpc.BlockEventID,
@ -685,13 +685,13 @@ func (c *WSClient) ReceiveBlocks(flt *neorpc.BlockFilter, rcvr chan<- *block.Blo
// //
// Deprecated: please, use ReceiveTransactions. This method will be removed in future versions. // Deprecated: please, use ReceiveTransactions. This method will be removed in future versions.
func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) { func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) {
var flt *neorpc.TxFilter var flt interface{}
if sender != nil || signer != nil { if sender != nil || signer != nil {
flt = &neorpc.TxFilter{Sender: sender, Signer: signer} flt = neorpc.TxFilter{Sender: sender, Signer: signer}
} }
params := []interface{}{"transaction_added"} params := []interface{}{"transaction_added"}
if flt != nil { if flt != nil {
params = append(params, *flt) params = append(params, flt)
} }
r := &naiveReceiver{ r := &naiveReceiver{
eventID: neorpc.TransactionEventID, eventID: neorpc.TransactionEventID,
@ -731,13 +731,13 @@ func (c *WSClient) ReceiveTransactions(flt *neorpc.TxFilter, rcvr chan<- *transa
// //
// Deprecated: please, use ReceiveExecutionNotifications. This method will be removed in future versions. // Deprecated: please, use ReceiveExecutionNotifications. This method will be removed in future versions.
func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) { func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) {
var flt *neorpc.NotificationFilter var flt interface{}
if contract != nil || name != nil { if contract != nil || name != nil {
flt = &neorpc.NotificationFilter{Contract: contract, Name: name} flt = neorpc.NotificationFilter{Contract: contract, Name: name}
} }
params := []interface{}{"notification_from_execution"} params := []interface{}{"notification_from_execution"}
if flt != nil { if flt != nil {
params = append(params, *flt) params = append(params, flt)
} }
r := &naiveReceiver{ r := &naiveReceiver{
eventID: neorpc.NotificationEventID, eventID: neorpc.NotificationEventID,
@ -777,18 +777,16 @@ func (c *WSClient) ReceiveExecutionNotifications(flt *neorpc.NotificationFilter,
// //
// Deprecated: please, use ReceiveExecutions. This method will be removed in future versions. // Deprecated: please, use ReceiveExecutions. This method will be removed in future versions.
func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) { func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) {
var flt *neorpc.ExecutionFilter var flt interface{}
if state != nil { if state != nil {
flt = &neorpc.ExecutionFilter{State: state} if *state != "HALT" && *state != "FAULT" {
return "", errors.New("bad state parameter")
}
flt = neorpc.ExecutionFilter{State: state}
} }
params := []interface{}{"transaction_executed"} params := []interface{}{"transaction_executed"}
if flt != nil { if flt != nil {
if flt.State != nil { params = append(params, flt)
if *flt.State != "HALT" && *flt.State != "FAULT" {
return "", errors.New("bad state parameter")
}
}
params = append(params, *flt)
} }
r := &naiveReceiver{ r := &naiveReceiver{
eventID: neorpc.ExecutionEventID, eventID: neorpc.ExecutionEventID,
@ -834,13 +832,13 @@ func (c *WSClient) ReceiveExecutions(flt *neorpc.ExecutionFilter, rcvr chan<- *s
// //
// Deprecated: please, use ReceiveNotaryRequests. This method will be removed in future versions. // Deprecated: please, use ReceiveNotaryRequests. This method will be removed in future versions.
func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) { func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) {
var flt *neorpc.TxFilter var flt interface{}
if sender != nil || mainSigner != nil { if sender != nil || mainSigner != nil {
flt = &neorpc.TxFilter{Sender: sender, Signer: mainSigner} flt = neorpc.TxFilter{Sender: sender, Signer: mainSigner}
} }
params := []interface{}{"notary_request_event"} params := []interface{}{"notary_request_event"}
if flt != nil { if flt != nil {
params = append(params, *flt) params = append(params, flt)
} }
r := &naiveReceiver{ r := &naiveReceiver{
eventID: neorpc.NotaryRequestEventID, eventID: neorpc.NotaryRequestEventID,

View file

@ -57,6 +57,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/nspcc-dev/neo-go/pkg/vm/vmstate" "github.com/nspcc-dev/neo-go/pkg/vm/vmstate"
"github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/nspcc-dev/neo-go/pkg/wallet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -2261,3 +2262,206 @@ waitloop:
} }
} }
} }
// TestWSClient_SubscriptionsCompat is aimed to test both deprecated and relevant
// subscriptions API with filtered and non-filtered subscriptions from the WSClient
// user side.
func TestWSClient_SubscriptionsCompat(t *testing.T) {
chain, rpcSrv, httpSrv := initClearServerWithServices(t, false, false, true)
defer chain.Close()
defer rpcSrv.Shutdown()
url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{})
require.NoError(t, err)
require.NoError(t, c.Init())
blocks := getTestBlocks(t)
bCount := uint32(0)
getData := func(t *testing.T) (*block.Block, int, util.Uint160, string, string) {
b1 := blocks[bCount]
primary := int(b1.PrimaryIndex)
tx := b1.Transactions[0]
sender := tx.Sender()
ntfName := "Transfer"
st := vmstate.Halt.String()
bCount++
return b1, primary, sender, ntfName, st
}
checkDeprecated := func(t *testing.T, filtered bool) {
b, primary, sender, ntfName, st := getData(t)
var bID, txID, ntfID, aerID string
if filtered {
bID, err = c.SubscribeForNewBlocks(&primary) //nolint:staticcheck // SA1019: c.SubscribeForNewBlocks is deprecated
require.NoError(t, err)
txID, err = c.SubscribeForNewTransactions(&sender, nil) //nolint:staticcheck // SA1019: c.SubscribeForNewTransactions is deprecated
require.NoError(t, err)
ntfID, err = c.SubscribeForExecutionNotifications(nil, &ntfName) //nolint:staticcheck // SA1019: c.SubscribeForExecutionNotifications is deprecated
require.NoError(t, err)
aerID, err = c.SubscribeForTransactionExecutions(&st) //nolint:staticcheck // SA1019: c.SubscribeForTransactionExecutions is deprecated
require.NoError(t, err)
} else {
bID, err = c.SubscribeForNewBlocks(nil) //nolint:staticcheck // SA1019: c.SubscribeForNewBlocks is deprecated
require.NoError(t, err)
txID, err = c.SubscribeForNewTransactions(nil, nil) //nolint:staticcheck // SA1019: c.SubscribeForNewTransactions is deprecated
require.NoError(t, err)
ntfID, err = c.SubscribeForExecutionNotifications(nil, nil) //nolint:staticcheck // SA1019: c.SubscribeForExecutionNotifications is deprecated
require.NoError(t, err)
aerID, err = c.SubscribeForTransactionExecutions(nil) //nolint:staticcheck // SA1019: c.SubscribeForTransactionExecutions is deprecated
require.NoError(t, err)
}
var (
lock sync.RWMutex
received byte
exitCh = make(chan struct{})
)
go func() {
dispatcher:
for {
select {
case ntf := <-c.Notifications: //nolint:staticcheck // SA1019: c.Notifications is deprecated
lock.Lock()
switch ntf.Type {
case neorpc.BlockEventID:
received |= 1
case neorpc.TransactionEventID:
received |= 1 << 1
case neorpc.NotificationEventID:
received |= 1 << 2
case neorpc.ExecutionEventID:
received |= 1 << 3
}
lock.Unlock()
case <-exitCh:
break dispatcher
}
}
drainLoop:
for {
select {
case <-c.Notifications: //nolint:staticcheck // SA1019: c.Notifications is deprecated
default:
break drainLoop
}
}
}()
// Accept the next block and wait for events.
require.NoError(t, chain.AddBlock(b))
assert.Eventually(t, func() bool {
lock.RLock()
defer lock.RUnlock()
return received == 1<<4-1
}, time.Second, 100*time.Millisecond)
require.NoError(t, c.Unsubscribe(bID))
require.NoError(t, c.Unsubscribe(txID))
require.NoError(t, c.Unsubscribe(ntfID))
require.NoError(t, c.Unsubscribe(aerID))
exitCh <- struct{}{}
}
t.Run("deprecated, filtered", func(t *testing.T) {
checkDeprecated(t, true)
})
t.Run("deprecated, non-filtered", func(t *testing.T) {
checkDeprecated(t, false)
})
checkRelevant := func(t *testing.T, filtered bool) {
b, primary, sender, ntfName, st := getData(t)
var (
bID, txID, ntfID, aerID string
blockCh = make(chan *block.Block)
txCh = make(chan *transaction.Transaction)
ntfCh = make(chan *state.ContainedNotificationEvent)
aerCh = make(chan *state.AppExecResult)
bFlt *neorpc.BlockFilter
txFlt *neorpc.TxFilter
ntfFlt *neorpc.NotificationFilter
aerFlt *neorpc.ExecutionFilter
)
if filtered {
bFlt = &neorpc.BlockFilter{Primary: &primary}
txFlt = &neorpc.TxFilter{Sender: &sender}
ntfFlt = &neorpc.NotificationFilter{Name: &ntfName}
aerFlt = &neorpc.ExecutionFilter{State: &st}
}
bID, err = c.ReceiveBlocks(bFlt, blockCh)
require.NoError(t, err)
txID, err = c.ReceiveTransactions(txFlt, txCh)
require.NoError(t, err)
ntfID, err = c.ReceiveExecutionNotifications(ntfFlt, ntfCh)
require.NoError(t, err)
aerID, err = c.ReceiveExecutions(aerFlt, aerCh)
require.NoError(t, err)
var (
lock sync.RWMutex
received byte
exitCh = make(chan struct{})
)
go func() {
dispatcher:
for {
select {
case <-blockCh:
lock.Lock()
received |= 1
lock.Unlock()
case <-txCh:
lock.Lock()
received |= 1 << 1
lock.Unlock()
case <-ntfCh:
lock.Lock()
received |= 1 << 2
lock.Unlock()
case <-aerCh:
lock.Lock()
received |= 1 << 3
lock.Unlock()
case <-exitCh:
break dispatcher
}
}
drainLoop:
for {
select {
case <-blockCh:
case <-txCh:
case <-ntfCh:
case <-aerCh:
default:
break drainLoop
}
}
close(blockCh)
close(txCh)
close(ntfCh)
close(aerCh)
}()
// Accept the next block and wait for events.
require.NoError(t, chain.AddBlock(b))
assert.Eventually(t, func() bool {
lock.RLock()
defer lock.RUnlock()
return received == 1<<4-1
}, time.Second, 100*time.Millisecond)
require.NoError(t, c.Unsubscribe(bID))
require.NoError(t, c.Unsubscribe(txID))
require.NoError(t, c.Unsubscribe(ntfID))
require.NoError(t, c.Unsubscribe(aerID))
exitCh <- struct{}{}
}
t.Run("relevant, filtered", func(t *testing.T) {
checkRelevant(t, true)
})
t.Run("relevant, non-filtered", func(t *testing.T) {
checkRelevant(t, false)
})
}