diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index f917a33ee..3f5f9acfa 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -191,7 +191,11 @@ func (p *Param) UnmarshalJSON(data []byte) error { case *NotificationFilter: p.Value = *val case *ExecutionFilter: - p.Value = *val + if (*val).State == "HALT" || (*val).State == "FAULT" { + p.Value = *val + } else { + continue + } case *[]Param: p.Value = *val } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 18a011f87..fad04857e 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -366,8 +366,8 @@ requestloop: s.subsLock.Lock() delete(s.subscribers, subscr) for _, e := range subscr.feeds { - if e != response.InvalidEventID { - s.unsubscribeFromChannel(e) + if e.event != response.InvalidEventID { + s.unsubscribeFromChannel(e.event) } } s.subsLock.Unlock() @@ -1145,6 +1145,31 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface if err != nil || event == response.MissedEventID { return nil, response.ErrInvalidParams } + // Optional filter. + var filter interface{} + p, ok = reqParams.Value(1) + if ok { + switch event { + case response.BlockEventID: + if p.Type != request.BlockFilterT { + return nil, response.ErrInvalidParams + } + case response.TransactionEventID: + if p.Type != request.TxFilterT { + return nil, response.ErrInvalidParams + } + case response.NotificationEventID: + if p.Type != request.NotificationFilterT { + return nil, response.ErrInvalidParams + } + case response.ExecutionEventID: + if p.Type != request.ExecutionFilterT { + return nil, response.ErrInvalidParams + } + } + filter = p.Value + } + s.subsLock.Lock() defer s.subsLock.Unlock() select { @@ -1154,14 +1179,15 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface } var id int for ; id < len(sub.feeds); id++ { - if sub.feeds[id] == response.InvalidEventID { + if sub.feeds[id].event == response.InvalidEventID { break } } if id == len(sub.feeds) { return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil) } - sub.feeds[id] = event + sub.feeds[id].event = event + sub.feeds[id].filter = filter s.subscribeToChannel(event) return strconv.FormatInt(int64(id), 10), nil } @@ -1206,11 +1232,12 @@ func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interfa } s.subsLock.Lock() defer s.subsLock.Unlock() - if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID { + if len(sub.feeds) <= id || sub.feeds[id].event == response.InvalidEventID { return nil, response.ErrInvalidParams } - event := sub.feeds[id] - sub.feeds[id] = response.InvalidEventID + event := sub.feeds[id].event + sub.feeds[id].event = response.InvalidEventID + sub.feeds[id].filter = nil s.unsubscribeFromChannel(event) return true, nil } @@ -1287,8 +1314,8 @@ chloop: if sub.overflown.Load() { continue } - for _, subID := range sub.feeds { - if subID == resp.Event { + for i := range sub.feeds { + if sub.feeds[i].Matches(&resp) { if msg == nil { b, err = json.Marshal(resp) if err != nil { diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index b3f64686f..f36d6c651 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -910,7 +910,7 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] t.Run("submit", func(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "submitblock", "params": ["%s"]}` t.Run("invalid signature", func(t *testing.T) { - s := newBlock(t, chain, 1) + s := newBlock(t, chain, 1, 0) s.Script.VerificationScript[8] ^= 0xff body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, s)), httpSrv.URL, t) checkErrGetResult(t, body, true) @@ -940,13 +940,13 @@ func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) [] } t.Run("invalid height", func(t *testing.T) { - b := newBlock(t, chain, 2, newTx()) + b := newBlock(t, chain, 2, 0, newTx()) body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) checkErrGetResult(t, body, true) }) t.Run("positive", func(t *testing.T) { - b := newBlock(t, chain, 1, newTx()) + b := newBlock(t, chain, 1, 0, newTx()) body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) data := checkErrGetResult(t, body, false) var res bool @@ -1114,7 +1114,7 @@ func encodeBlock(t *testing.T, b *block.Block) string { return hex.EncodeToString(w.Bytes()) } -func newBlock(t *testing.T, bc blockchainer.Blockchainer, index uint32, txs ...*transaction.Transaction) *block.Block { +func newBlock(t *testing.T, bc blockchainer.Blockchainer, index uint32, primary uint32, txs ...*transaction.Transaction) *block.Block { witness := transaction.Witness{VerificationScript: testchain.MultisigVerificationScript()} height := bc.BlockHeight() h := bc.GetHeaderHash(int(height)) @@ -1129,7 +1129,7 @@ func newBlock(t *testing.T, bc blockchainer.Blockchainer, index uint32, txs ...* Script: witness, }, ConsensusData: block.ConsensusData{ - PrimaryIndex: 0, + PrimaryIndex: primary, Nonce: 1111, }, Transactions: txs, diff --git a/pkg/rpc/server/subscription.go b/pkg/rpc/server/subscription.go index f4c736b08..16433ce51 100644 --- a/pkg/rpc/server/subscription.go +++ b/pkg/rpc/server/subscription.go @@ -2,7 +2,11 @@ package server import ( "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" + "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" "go.uber.org/atomic" ) @@ -16,7 +20,11 @@ type ( // cheaper doing it this way rather than creating a map), // pointing to EventID is an obvious overkill at the moment, but // that's not for long. - feeds [maxFeeds]response.EventID + feeds [maxFeeds]feed + } + feed struct { + event response.EventID + filter interface{} } ) @@ -34,3 +42,42 @@ const ( // a lot in terms of memory used. notificationBufSize = 1024 ) + +func (f *feed) Matches(r *response.Notification) bool { + if r.Event != f.event { + return false + } + if f.filter == nil { + return true + } + switch f.event { + case response.BlockEventID: + filt := f.filter.(request.BlockFilter) + b := r.Payload[0].(*block.Block) + return int(b.ConsensusData.PrimaryIndex) == filt.Primary + case response.TransactionEventID: + filt := f.filter.(request.TxFilter) + tx := r.Payload[0].(*transaction.Transaction) + senderOK := filt.Sender == nil || tx.Sender.Equals(*filt.Sender) + cosignerOK := true + if filt.Cosigner != nil { + cosignerOK = false + for i := range tx.Cosigners { + if tx.Cosigners[i].Account.Equals(*filt.Cosigner) { + cosignerOK = true + break + } + } + } + return senderOK && cosignerOK + case response.NotificationEventID: + filt := f.filter.(request.NotificationFilter) + notification := r.Payload[0].(result.NotificationEvent) + return notification.Contract.Equals(filt.Contract) + case response.ExecutionEventID: + filt := f.filter.(request.ExecutionFilter) + applog := r.Payload[0].(result.ApplicationLog) + return len(applog.Executions) != 0 && applog.Executions[0].VMState == filt.State + } + return false +} diff --git a/pkg/rpc/server/subscription_test.go b/pkg/rpc/server/subscription_test.go index 27a5397ef..162674549 100644 --- a/pkg/rpc/server/subscription_test.go +++ b/pkg/rpc/server/subscription_test.go @@ -10,6 +10,8 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/encoding/address" + "github.com/nspcc-dev/neo-go/pkg/internal/testchain" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -62,6 +64,24 @@ func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *webso return chain, rpcSrv, ws, respMsgs, finishedFlag } +func callSubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, params string) string { + var s string + resp := callWSGetRaw(t, ws, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, params), msgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &s)) + return s +} + +func callUnsubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, id string) { + var b bool + resp := callWSGetRaw(t, ws, fmt.Sprintf(`{"jsonrpc": "2.0","method": "unsubscribe","params": ["%s"],"id": 1}`, id), msgs) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) + require.NoError(t, json.Unmarshal(resp.Result, &b)) + require.Equal(t, true, b) +} + func TestSubscriptions(t *testing.T) { var subIDs = make([]string, 0) var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"} @@ -72,16 +92,7 @@ func TestSubscriptions(t *testing.T) { defer rpcSrv.Shutdown() for _, feed := range subFeeds { - var s string - resp := callWSGetRaw(t, c, fmt.Sprintf(`{ - "jsonrpc": "2.0", - "method": "subscribe", - "params": ["%s"], - "id": 1 -}`, feed), respMsgs) - require.Nil(t, resp.Error) - require.NotNil(t, resp.Result) - require.NoError(t, json.Unmarshal(resp.Result, &s)) + s := callSubscribe(t, c, respMsgs, fmt.Sprintf(`["%s"]`, feed)) subIDs = append(subIDs, s) } @@ -109,23 +120,173 @@ func TestSubscriptions(t *testing.T) { } for _, id := range subIDs { - var b bool - - resp := callWSGetRaw(t, c, fmt.Sprintf(`{ - "jsonrpc": "2.0", - "method": "unsubscribe", - "params": ["%s"], - "id": 1 -}`, id), respMsgs) - require.Nil(t, resp.Error) - require.NotNil(t, resp.Result) - require.NoError(t, json.Unmarshal(resp.Result, &b)) - require.Equal(t, true, b) + callUnsubscribe(t, c, respMsgs, id) } finishedFlag.CAS(false, true) c.Close() } +func TestFilteredSubscriptions(t *testing.T) { + priv0 := testchain.PrivateKeyByID(0) + var goodSender = priv0.GetScriptHash() + + var cases = map[string]struct { + params string + check func(*testing.T, *response.Notification) + }{ + "tx matching sender": { + params: `["transaction_added", {"sender":"` + goodSender.StringLE() + `"}]`, + check: func(t *testing.T, resp *response.Notification) { + rmap := resp.Payload[0].(map[string]interface{}) + require.Equal(t, response.TransactionEventID, resp.Event) + sender := rmap["sender"].(string) + require.Equal(t, address.Uint160ToString(goodSender), sender) + }, + }, + "tx matching cosigner": { + params: `["transaction_added", {"cosigner":"` + goodSender.StringLE() + `"}]`, + check: func(t *testing.T, resp *response.Notification) { + rmap := resp.Payload[0].(map[string]interface{}) + require.Equal(t, response.TransactionEventID, resp.Event) + cosigners := rmap["cosigners"].([]interface{}) + cosigner0 := cosigners[0].(map[string]interface{}) + cosigner0acc := cosigner0["account"].(string) + require.Equal(t, "0x"+goodSender.StringLE(), cosigner0acc) + }, + }, + "tx matching sender and cosigner": { + params: `["transaction_added", {"sender":"` + goodSender.StringLE() + `", "cosigner":"` + goodSender.StringLE() + `"}]`, + check: func(t *testing.T, resp *response.Notification) { + rmap := resp.Payload[0].(map[string]interface{}) + require.Equal(t, response.TransactionEventID, resp.Event) + sender := rmap["sender"].(string) + require.Equal(t, address.Uint160ToString(goodSender), sender) + cosigners := rmap["cosigners"].([]interface{}) + cosigner0 := cosigners[0].(map[string]interface{}) + cosigner0acc := cosigner0["account"].(string) + require.Equal(t, "0x"+goodSender.StringLE(), cosigner0acc) + }, + }, + "notification matching": { + params: `["notification_from_execution", {"contract":"` + testContractHash + `"}]`, + check: func(t *testing.T, resp *response.Notification) { + rmap := resp.Payload[0].(map[string]interface{}) + require.Equal(t, response.NotificationEventID, resp.Event) + c := rmap["contract"].(string) + require.Equal(t, "0x"+testContractHash, c) + }, + }, + "execution matching": { + params: `["transaction_executed", {"state":"HALT"}]`, + check: func(t *testing.T, resp *response.Notification) { + rmap := resp.Payload[0].(map[string]interface{}) + require.Equal(t, response.ExecutionEventID, resp.Event) + execs := rmap["executions"].([]interface{}) + exec0 := execs[0].(map[string]interface{}) + st := exec0["vmstate"].(string) + require.Equal(t, "HALT", st) + }, + }, + "tx non-matching": { + params: `["transaction_added", {"sender":"00112233445566778899aabbccddeeff00112233"}]`, + check: func(t *testing.T, _ *response.Notification) { + t.Fatal("unexpected match for EnrollmentTransaction") + }, + }, + "notification non-matching": { + params: `["notification_from_execution", {"contract":"00112233445566778899aabbccddeeff00112233"}]`, + check: func(t *testing.T, _ *response.Notification) { + t.Fatal("unexpected match for contract 00112233445566778899aabbccddeeff00112233") + }, + }, + "execution non-matching": { + params: `["transaction_executed", {"state":"FAULT"}]`, + check: func(t *testing.T, _ *response.Notification) { + t.Fatal("unexpected match for faulted execution") + }, + }, + } + + for name, this := range cases { + t.Run(name, func(t *testing.T) { + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + // It's used as an end-of-event-stream, so it's always present. + blockSubID := callSubscribe(t, c, respMsgs, `["block_added"]`) + subID := callSubscribe(t, c, respMsgs, this.params) + + var lastBlock uint32 + for _, b := range getTestBlocks(t) { + require.NoError(t, chain.AddBlock(b)) + lastBlock = b.Index + } + + for { + resp := getNotification(t, respMsgs) + rmap := resp.Payload[0].(map[string]interface{}) + if resp.Event == response.BlockEventID { + index := rmap["height"].(float64) + if uint32(index) == lastBlock { + break + } + continue + } + this.check(t, resp) + } + + callUnsubscribe(t, c, respMsgs, subID) + callUnsubscribe(t, c, respMsgs, blockSubID) + finishedFlag.CAS(false, true) + c.Close() + }) + } +} + +func TestFilteredBlockSubscriptions(t *testing.T) { + // We can't fit this into TestFilteredSubscriptions, because it uses + // blocks as EOF events to wait for. + const numBlocks = 10 + chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) + + defer chain.Close() + defer rpcSrv.Shutdown() + + blockSubID := callSubscribe(t, c, respMsgs, `["block_added", {"primary":3}]`) + + var expectedCnt int + for i := 0; i < numBlocks; i++ { + primary := uint32(i % 4) + if primary == 3 { + expectedCnt++ + } + b := newBlock(t, chain, 1, primary) + require.NoError(t, chain.AddBlock(b)) + } + + for i := 0; i < expectedCnt; i++ { + var resp = new(response.Notification) + select { + case body := <-respMsgs: + require.NoError(t, json.Unmarshal(body, resp)) + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + + require.Equal(t, response.BlockEventID, resp.Event) + rmap := resp.Payload[0].(map[string]interface{}) + cd := rmap["consensus_data"].(map[string]interface{}) + primary := cd["primary"].(float64) + require.Equal(t, 3, int(primary)) + + } + callUnsubscribe(t, c, respMsgs, blockSubID) + finishedFlag.CAS(false, true) + c.Close() +} + func TestMaxSubscriptions(t *testing.T) { var subIDs = make([]string, 0) chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) @@ -161,6 +322,12 @@ func TestBadSubUnsub(t *testing.T) { "bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`, "bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`, "missed event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["event_missed"], "id": 1}`, + "block invalid filter": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added", 1], "id": 1}`, + "tx filter 1": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", 1], "id": 1}`, + "tx filter 2": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", {"state": "HALT"}], "id": 1}`, + "notification filter": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["notification_from_execution", "contract"], "id": 1}`, + "execution filter 1": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", "FAULT"], "id": 1}`, + "execution filter 2": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", {"state": "STOP"}], "id": 1}`, } var unsubCases = map[string]string{ "no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`, @@ -247,7 +414,7 @@ func testSubscriptionOverflow(t *testing.T) { // Push a lot of new blocks, but don't read events for them. for i := 0; i < blockCnt; i++ { - b := newBlock(t, chain, 1) + b := newBlock(t, chain, 1, 0) require.NoError(t, chain.AddBlock(b)) } for i := 0; i < blockCnt; i++ {