diff --git a/pkg/rpcclient/local.go b/pkg/rpcclient/local.go new file mode 100644 index 000000000..19e67ceaa --- /dev/null +++ b/pkg/rpcclient/local.go @@ -0,0 +1,82 @@ +package rpcclient + +import ( + "context" + + "github.com/nspcc-dev/neo-go/pkg/neorpc" + "go.uber.org/atomic" +) + +// InternalHook is a function signature that is required to create a local client +// (see NewInternal). It performs registration of local client's event channel +// and returns a request handler function. +type InternalHook func(context.Context, chan<- neorpc.Notification) func(*neorpc.Request) (*neorpc.Response, error) + +// Internal is an experimental "local" client that does not connect to RPC via +// network. It's made for deeply integrated applications like NeoFS that have +// blockchain running in the same process, so use it only if you know what you're +// doing. It provides the same interface WSClient does. +type Internal struct { + WSClient + + events chan neorpc.Notification +} + +// NewInternal creates an instance of internal client. It accepts a method +// provided by RPC server. +func NewInternal(ctx context.Context, register InternalHook) (*Internal, error) { + c := &Internal{ + WSClient: WSClient{ + Client: Client{}, + Notifications: make(chan Notification), + + shutdown: make(chan struct{}), + done: make(chan struct{}), + closeCalled: *atomic.NewBool(false), + subscriptions: make(map[string]notificationReceiver), + receivers: make(map[interface{}][]string), + }, + events: make(chan neorpc.Notification), + } + + err := initClient(ctx, &c.WSClient.Client, "localhost:0", Options{}) + if err != nil { + return nil, err // Can't really happen for internal client. + } + c.cli = nil + go c.eventLoop() + // c.ctx is inherited from ctx in fact (see initClient). + c.requestF = register(c.ctx, c.events) //nolint:contextcheck // Non-inherited new context, use function like `context.WithXXX` instead + return c, nil +} + +func (c *Internal) eventLoop() { +eventloop: + for { + select { + case <-c.ctx.Done(): + break eventloop + case <-c.shutdown: + break eventloop + case ev := <-c.events: + ntf := Notification{Type: ev.Event} + if len(ev.Payload) > 0 { + ntf.Value = ev.Payload[0] + } + c.notifySubscribers(ntf) + } + } + close(c.done) + close(c.Notifications) + c.ctxCancel() + // ctx is cancelled, server is notified and will finish soon. +drainloop: + for { + select { + case <-c.events: + default: + break drainloop + } + } + close(c.events) +} diff --git a/pkg/rpcclient/local_test.go b/pkg/rpcclient/local_test.go new file mode 100644 index 000000000..0be93a646 --- /dev/null +++ b/pkg/rpcclient/local_test.go @@ -0,0 +1,18 @@ +package rpcclient + +import ( + "context" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/neorpc" + "github.com/stretchr/testify/require" +) + +func TestInternalClientClose(t *testing.T) { + icl, err := NewInternal(context.TODO(), func(ctx context.Context, ch chan<- neorpc.Notification) func(*neorpc.Request) (*neorpc.Response, error) { + return nil + }) + require.NoError(t, err) + icl.Close() + require.NoError(t, icl.GetError()) +} diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 08dc72310..8cb224ccf 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -456,7 +456,7 @@ readloop: connCloseErr = fmt.Errorf("bad event received: %s / %d", event, len(rr.RawParams)) break readloop } - var val interface{} + ntf := Notification{Type: event} switch event { case neorpc.BlockEventID: sr, err := c.StateRootInHeader() @@ -465,15 +465,15 @@ readloop: connCloseErr = fmt.Errorf("failed to fetch StateRootInHeader: %w", err) break readloop } - val = block.New(sr) + ntf.Value = block.New(sr) case neorpc.TransactionEventID: - val = &transaction.Transaction{} + ntf.Value = &transaction.Transaction{} case neorpc.NotificationEventID: - val = new(state.ContainedNotificationEvent) + ntf.Value = new(state.ContainedNotificationEvent) case neorpc.ExecutionEventID: - val = new(state.AppExecResult) + ntf.Value = new(state.AppExecResult) case neorpc.NotaryRequestEventID: - val = new(result.NotaryRequestEvent) + ntf.Value = new(result.NotaryRequestEvent) case neorpc.MissedEventID: // No value. default: @@ -482,32 +482,14 @@ readloop: break readloop } if event != neorpc.MissedEventID { - err = json.Unmarshal(rr.RawParams[0], val) + err = json.Unmarshal(rr.RawParams[0], ntf.Value) if err != nil { // Bad event received. connCloseErr = fmt.Errorf("failed to unmarshal event of type %s from JSON: %w", event, err) break readloop } } - if event == neorpc.MissedEventID { - c.subscriptionsLock.Lock() - for rcvr, ids := range c.receivers { - c.subscriptions[ids[0]].Close() - delete(c.receivers, rcvr) - } - c.subscriptionsLock.Unlock() - continue readloop - } - c.subscriptionsLock.RLock() - ntf := Notification{Type: event, Value: val} - for _, ids := range c.receivers { - for _, id := range ids { - if c.subscriptions[id].TrySend(ntf) { - break // strictly one notification per channel - } - } - } - c.subscriptionsLock.RUnlock() + c.notifySubscribers(ntf) } else if rr.ID != nil && (rr.Error != nil || rr.Result != nil) { id, err := strconv.ParseUint(string(rr.ID), 10, 64) if err != nil { @@ -580,6 +562,27 @@ writeloop: } } +func (c *WSClient) notifySubscribers(ntf Notification) { + if ntf.Type == neorpc.MissedEventID { + c.subscriptionsLock.Lock() + for rcvr, ids := range c.receivers { + c.subscriptions[ids[0]].Close() + delete(c.receivers, rcvr) + } + c.subscriptionsLock.Unlock() + return + } + c.subscriptionsLock.RLock() + for _, ids := range c.receivers { + for _, id := range ids { + if c.subscriptions[id].TrySend(ntf) { + break // strictly one notification per channel + } + } + } + c.subscriptionsLock.RUnlock() +} + func (c *WSClient) unregisterRespChannel(id uint64) { c.respLock.Lock() defer c.respLock.Unlock() diff --git a/pkg/services/rpcsrv/client_test.go b/pkg/services/rpcsrv/client_test.go index 372333efa..031dab5ee 100644 --- a/pkg/services/rpcsrv/client_test.go +++ b/pkg/services/rpcsrv/client_test.go @@ -2033,15 +2033,45 @@ func TestClient_Wait(t *testing.T) { check(t, util.Uint256{1, 2, 3}, chain.BlockHeight()-1, true) } -func TestWSClient_Wait(t *testing.T) { +func mkSubsClient(t *testing.T, rpcSrv *Server, httpSrv *httptest.Server, local bool) *rpcclient.WSClient { + var ( + c *rpcclient.WSClient + err error + icl *rpcclient.Internal + ) + if local { + icl, err = rpcclient.NewInternal(context.Background(), rpcSrv.RegisterLocal) + } else { + url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + c, err = rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + } + require.NoError(t, err) + if local { + c = &icl.WSClient + } + require.NoError(t, c.Init()) + return c +} + +func runWSAndLocal(t *testing.T, test func(*testing.T, bool)) { + t.Run("ws", func(t *testing.T) { + test(t, false) + }) + t.Run("local", func(t *testing.T) { + test(t, true) + }) +} + +func TestSubClientWait(t *testing.T) { + runWSAndLocal(t, testSubClientWait) +} + +func testSubClientWait(t *testing.T, local bool) { chain, rpcSrv, httpSrv := initClearServerWithServices(t, false, false, true) defer chain.Close() defer rpcSrv.Shutdown() - url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) - require.NoError(t, err) - require.NoError(t, c.Init()) + c := mkSubsClient(t, rpcSrv, httpSrv, local) acc, err := wallet.NewAccount() require.NoError(t, err) act, err := actor.New(c, []actor.SignerAccount{ @@ -2135,15 +2165,16 @@ func TestWSClient_Wait(t *testing.T) { require.True(t, faultedChecked, "FAULTed transaction wasn't checked") } -func TestWSClient_WaitWithLateSubscription(t *testing.T) { +func TestSubClientWaitWithLateSubscription(t *testing.T) { + runWSAndLocal(t, testSubClientWaitWithLateSubscription) +} + +func testSubClientWaitWithLateSubscription(t *testing.T, local bool) { chain, rpcSrv, httpSrv := initClearServerWithServices(t, false, false, true) defer chain.Close() defer rpcSrv.Shutdown() - url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) - require.NoError(t, err) - require.NoError(t, c.Init()) + c := mkSubsClient(t, rpcSrv, httpSrv, local) acc, err := wallet.NewAccount() require.NoError(t, err) act, err := actor.New(c, []actor.SignerAccount{ @@ -2182,15 +2213,16 @@ func TestWSClientHandshakeError(t *testing.T) { require.ErrorContains(t, err, "websocket users limit reached") } -func TestWSClient_WaitWithMissedEvent(t *testing.T) { +func TestSubClientWaitWithMissedEvent(t *testing.T) { + runWSAndLocal(t, testSubClientWaitWithMissedEvent) +} + +func testSubClientWaitWithMissedEvent(t *testing.T, local bool) { chain, rpcSrv, httpSrv := initClearServerWithServices(t, false, false, true) defer chain.Close() defer rpcSrv.Shutdown() - url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) - require.NoError(t, err) - require.NoError(t, c.Init()) + c := mkSubsClient(t, rpcSrv, httpSrv, local) acc, err := wallet.NewAccount() require.NoError(t, err) act, err := actor.New(c, []actor.SignerAccount{ @@ -2272,10 +2304,7 @@ func TestWSClient_SubscriptionsCompat(t *testing.T) { defer chain.Close() defer rpcSrv.Shutdown() - url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) - require.NoError(t, err) - require.NoError(t, c.Init()) + c := mkSubsClient(t, rpcSrv, httpSrv, false) blocks := getTestBlocks(t) bCount := uint32(0) @@ -2290,8 +2319,11 @@ func TestWSClient_SubscriptionsCompat(t *testing.T) { return b1, primary, sender, ntfName, st } checkDeprecated := func(t *testing.T, filtered bool) { + var ( + bID, txID, ntfID, aerID string + err error + ) b, primary, sender, ntfName, st := getData(t) - var bID, txID, ntfID, aerID string if filtered { bID, err = c.SubscribeForNewBlocks(&primary) //nolint:staticcheck // SA1019: c.SubscribeForNewBlocks is deprecated require.NoError(t, err) @@ -2382,6 +2414,7 @@ func TestWSClient_SubscriptionsCompat(t *testing.T) { txFlt *neorpc.TxFilter ntfFlt *neorpc.NotificationFilter aerFlt *neorpc.ExecutionFilter + err error ) if filtered { bFlt = &neorpc.BlockFilter{Primary: &primary} diff --git a/pkg/services/rpcsrv/local_test.go b/pkg/services/rpcsrv/local_test.go new file mode 100644 index 000000000..50e7660da --- /dev/null +++ b/pkg/services/rpcsrv/local_test.go @@ -0,0 +1,57 @@ +package rpcsrv + +import ( + "context" + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/internal/testchain" + "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/rpcclient" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/actor" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/gas" + "github.com/nspcc-dev/neo-go/pkg/rpcclient/invoker" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/wallet" + "github.com/stretchr/testify/require" +) + +func TestLocalClient(t *testing.T) { + _, rpcSrv, _ := initClearServerWithCustomConfig(t, func(cfg *config.Config) { + // No addresses configured -> RPC server listens nothing (but it + // has MaxGasInvoke, sessions and other stuff). + cfg.ApplicationConfiguration.RPC.BasicService.Enabled = true + cfg.ApplicationConfiguration.RPC.BasicService.Address = nil //nolint:staticcheck // SA1019: cfg.ApplicationConfiguration.RPC.BasicService.Address is deprecated + cfg.ApplicationConfiguration.RPC.BasicService.Port = nil //nolint:staticcheck // SA1019: cfg.ApplicationConfiguration.RPC.BasicService.Port is deprecated + cfg.ApplicationConfiguration.RPC.BasicService.Addresses = nil + cfg.ApplicationConfiguration.RPC.TLSConfig.Address = nil //nolint:staticcheck // SA1019: cfg.ApplicationConfiguration.RPC.TLSConfig.Address is deprecated + cfg.ApplicationConfiguration.RPC.TLSConfig.Port = nil //nolint:staticcheck // SA1019: cfg.ApplicationConfiguration.RPC.TLSConfig.Port is deprecated + cfg.ApplicationConfiguration.RPC.TLSConfig.Addresses = nil + }) + // RPC server listens nothing (not exposed in any way), but it works for internal clients. + c, err := rpcclient.NewInternal(context.TODO(), rpcSrv.RegisterLocal) + require.NoError(t, err) + require.NoError(t, c.Init()) + + // Invokers can use this client. + gasReader := gas.NewReader(invoker.New(c, nil)) + d, err := gasReader.Decimals() + require.NoError(t, err) + require.EqualValues(t, 8, d) + + // Actors can use it as well + priv := testchain.PrivateKeyByID(0) + acc := wallet.NewAccountFromPrivateKey(priv) + addr := priv.PublicKey().GetScriptHash() + + act, err := actor.NewSimple(c, acc) + require.NoError(t, err) + gasprom := gas.New(act) + txHash, _, err := gasprom.Transfer(addr, util.Uint160{}, big.NewInt(1000), nil) + require.NoError(t, err) + // No new blocks are produced here, but the tx is OK and is in the mempool. + txes, err := c.GetRawMemPool() + require.NoError(t, err) + require.Equal(t, []util.Uint256{txHash}, txes) + // Subscriptions are checked by other tests. +}