rpc/server: add notification filters

And check state string correctness on unmarshaling.
This commit is contained in:
Roman Khimov 2020-05-13 17:13:33 +03:00
parent 78716c5335
commit 8f55f0ac76
4 changed files with 192 additions and 33 deletions

View file

@ -182,7 +182,11 @@ func (p *Param) UnmarshalJSON(data []byte) error {
case *NotificationFilter: case *NotificationFilter:
p.Value = *val p.Value = *val
case *ExecutionFilter: case *ExecutionFilter:
if (*val).State == "HALT" || (*val).State == "FAULT" {
p.Value = *val p.Value = *val
} else {
continue
}
case *[]Param: case *[]Param:
p.Value = *val p.Value = *val
} }

View file

@ -365,8 +365,8 @@ requestloop:
s.subsLock.Lock() s.subsLock.Lock()
delete(s.subscribers, subscr) delete(s.subscribers, subscr)
for _, e := range subscr.feeds { for _, e := range subscr.feeds {
if e != response.InvalidEventID { if e.event != response.InvalidEventID {
s.unsubscribeFromChannel(e) s.unsubscribeFromChannel(e.event)
} }
} }
s.subsLock.Unlock() s.subsLock.Unlock()
@ -1146,6 +1146,32 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
if err != nil || event == response.MissedEventID { if err != nil || event == response.MissedEventID {
return nil, response.ErrInvalidParams return nil, response.ErrInvalidParams
} }
// Optional filter.
var filter interface{}
p, ok = reqParams.Value(1)
if ok {
// It doesn't accept filters.
if event == response.BlockEventID {
return nil, response.ErrInvalidParams
}
switch event {
case response.TransactionEventID:
if p.Type != request.TxFilterT {
return nil, response.ErrInvalidParams
}
case response.NotificationEventID:
if p.Type != request.NotificationFilterT {
return nil, response.ErrInvalidParams
}
case response.ExecutionEventID:
if p.Type != request.ExecutionFilterT {
return nil, response.ErrInvalidParams
}
}
filter = p.Value
}
s.subsLock.Lock() s.subsLock.Lock()
defer s.subsLock.Unlock() defer s.subsLock.Unlock()
select { select {
@ -1155,14 +1181,15 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
} }
var id int var id int
for ; id < len(sub.feeds); id++ { for ; id < len(sub.feeds); id++ {
if sub.feeds[id] == response.InvalidEventID { if sub.feeds[id].event == response.InvalidEventID {
break break
} }
} }
if id == len(sub.feeds) { if id == len(sub.feeds) {
return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil) return nil, response.NewInternalServerError("maximum number of subscriptions is reached", nil)
} }
sub.feeds[id] = event sub.feeds[id].event = event
sub.feeds[id].filter = filter
s.subscribeToChannel(event) s.subscribeToChannel(event)
return strconv.FormatInt(int64(id), 10), nil return strconv.FormatInt(int64(id), 10), nil
} }
@ -1207,11 +1234,12 @@ func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interfa
} }
s.subsLock.Lock() s.subsLock.Lock()
defer s.subsLock.Unlock() defer s.subsLock.Unlock()
if len(sub.feeds) <= id || sub.feeds[id] == response.InvalidEventID { if len(sub.feeds) <= id || sub.feeds[id].event == response.InvalidEventID {
return nil, response.ErrInvalidParams return nil, response.ErrInvalidParams
} }
event := sub.feeds[id] event := sub.feeds[id].event
sub.feeds[id] = response.InvalidEventID sub.feeds[id].event = response.InvalidEventID
sub.feeds[id].filter = nil
s.unsubscribeFromChannel(event) s.unsubscribeFromChannel(event)
return true, nil return true, nil
} }
@ -1288,8 +1316,8 @@ chloop:
if sub.overflown.Load() { if sub.overflown.Load() {
continue continue
} }
for _, subID := range sub.feeds { for i := range sub.feeds {
if subID == resp.Event { if sub.feeds[i].Matches(&resp) {
if msg == nil { if msg == nil {
b, err = json.Marshal(resp) b, err = json.Marshal(resp)
if err != nil { if err != nil {

View file

@ -2,7 +2,10 @@ package server
import ( import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
"go.uber.org/atomic" "go.uber.org/atomic"
) )
@ -16,7 +19,11 @@ type (
// cheaper doing it this way rather than creating a map), // cheaper doing it this way rather than creating a map),
// pointing to EventID is an obvious overkill at the moment, but // pointing to EventID is an obvious overkill at the moment, but
// that's not for long. // that's not for long.
feeds [maxFeeds]response.EventID feeds [maxFeeds]feed
}
feed struct {
event response.EventID
filter interface{}
} }
) )
@ -34,3 +41,27 @@ const (
// a lot in terms of memory used. // a lot in terms of memory used.
notificationBufSize = 1024 notificationBufSize = 1024
) )
func (f *feed) Matches(r *response.Notification) bool {
if r.Event != f.event {
return false
}
if f.filter == nil {
return true
}
switch f.event {
case response.TransactionEventID:
filt := f.filter.(request.TxFilter)
tx := r.Payload[0].(*transaction.Transaction)
return tx.Type == filt.Type
case response.NotificationEventID:
filt := f.filter.(request.NotificationFilter)
notification := r.Payload[0].(result.NotificationEvent)
return notification.Contract.Equals(filt.Contract)
case response.ExecutionEventID:
filt := f.filter.(request.ExecutionFilter)
applog := r.Payload[0].(result.ApplicationLog)
return len(applog.Executions) != 0 && applog.Executions[0].VMState == filt.State
}
return false
}

View file

@ -62,6 +62,24 @@ func initCleanServerAndWSClient(t *testing.T) (*core.Blockchain, *Server, *webso
return chain, rpcSrv, ws, respMsgs, finishedFlag return chain, rpcSrv, ws, respMsgs, finishedFlag
} }
func callSubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, params string) string {
var s string
resp := callWSGetRaw(t, ws, fmt.Sprintf(`{"jsonrpc": "2.0","method": "subscribe","params": %s,"id": 1}`, params), msgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &s))
return s
}
func callUnsubscribe(t *testing.T, ws *websocket.Conn, msgs <-chan []byte, id string) {
var b bool
resp := callWSGetRaw(t, ws, fmt.Sprintf(`{"jsonrpc": "2.0","method": "unsubscribe","params": ["%s"],"id": 1}`, id), msgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &b))
require.Equal(t, true, b)
}
func TestSubscriptions(t *testing.T) { func TestSubscriptions(t *testing.T) {
var subIDs = make([]string, 0) var subIDs = make([]string, 0)
var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"} var subFeeds = []string{"block_added", "transaction_added", "notification_from_execution", "transaction_executed"}
@ -72,16 +90,7 @@ func TestSubscriptions(t *testing.T) {
defer rpcSrv.Shutdown() defer rpcSrv.Shutdown()
for _, feed := range subFeeds { for _, feed := range subFeeds {
var s string s := callSubscribe(t, c, respMsgs, fmt.Sprintf(`["%s"]`, feed))
resp := callWSGetRaw(t, c, fmt.Sprintf(`{
"jsonrpc": "2.0",
"method": "subscribe",
"params": ["%s"],
"id": 1
}`, feed), respMsgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &s))
subIDs = append(subIDs, s) subIDs = append(subIDs, s)
} }
@ -109,23 +118,104 @@ func TestSubscriptions(t *testing.T) {
} }
for _, id := range subIDs { for _, id := range subIDs {
var b bool callUnsubscribe(t, c, respMsgs, id)
resp := callWSGetRaw(t, c, fmt.Sprintf(`{
"jsonrpc": "2.0",
"method": "unsubscribe",
"params": ["%s"],
"id": 1
}`, id), respMsgs)
require.Nil(t, resp.Error)
require.NotNil(t, resp.Result)
require.NoError(t, json.Unmarshal(resp.Result, &b))
require.Equal(t, true, b)
} }
finishedFlag.CAS(false, true) finishedFlag.CAS(false, true)
c.Close() c.Close()
} }
func TestFilteredSubscriptions(t *testing.T) {
var cases = map[string]struct {
params string
check func(*testing.T, *response.Notification)
}{
"tx matching": {
params: `["transaction_added", {"type":"InvocationTransaction"}]`,
check: func(t *testing.T, resp *response.Notification) {
rmap := resp.Payload[0].(map[string]interface{})
require.Equal(t, response.TransactionEventID, resp.Event)
typ := rmap["type"].(string)
require.Equal(t, "InvocationTransaction", typ)
},
},
"notification matching": {
params: `["notification_from_execution", {"contract":"` + testContractHash + `"}]`,
check: func(t *testing.T, resp *response.Notification) {
rmap := resp.Payload[0].(map[string]interface{})
require.Equal(t, response.NotificationEventID, resp.Event)
c := rmap["contract"].(string)
require.Equal(t, "0x"+testContractHash, c)
},
},
"execution matching": {
params: `["transaction_executed", {"state":"HALT"}]`,
check: func(t *testing.T, resp *response.Notification) {
rmap := resp.Payload[0].(map[string]interface{})
require.Equal(t, response.ExecutionEventID, resp.Event)
execs := rmap["executions"].([]interface{})
exec0 := execs[0].(map[string]interface{})
st := exec0["vmstate"].(string)
require.Equal(t, "HALT", st)
},
},
"tx non-matching": {
params: `["transaction_added", {"type":"EnrollmentTransaction"}]`,
check: func(t *testing.T, _ *response.Notification) {
t.Fatal("unexpected match for EnrollmentTransaction")
},
},
"notification non-matching": {
params: `["notification_from_execution", {"contract":"00112233445566778899aabbccddeeff00112233"}]`,
check: func(t *testing.T, _ *response.Notification) {
t.Fatal("unexpected match for contract 00112233445566778899aabbccddeeff00112233")
},
},
"execution non-matching": {
params: `["transaction_executed", {"state":"FAULT"}]`,
check: func(t *testing.T, _ *response.Notification) {
t.Fatal("unexpected match for faulted execution")
},
},
}
for name, this := range cases {
t.Run(name, func(t *testing.T) {
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
defer chain.Close()
defer rpcSrv.Shutdown()
// It's used as an end-of-event-stream, so it's always present.
blockSubID := callSubscribe(t, c, respMsgs, `["block_added"]`)
subID := callSubscribe(t, c, respMsgs, this.params)
var lastBlock uint32
for _, b := range getTestBlocks(t) {
require.NoError(t, chain.AddBlock(b))
lastBlock = b.Index
}
for {
resp := getNotification(t, respMsgs)
rmap := resp.Payload[0].(map[string]interface{})
if resp.Event == response.BlockEventID {
index := rmap["height"].(float64)
if uint32(index) == lastBlock {
break
}
continue
}
this.check(t, resp)
}
callUnsubscribe(t, c, respMsgs, subID)
callUnsubscribe(t, c, respMsgs, blockSubID)
finishedFlag.CAS(false, true)
c.Close()
})
}
}
func TestMaxSubscriptions(t *testing.T) { func TestMaxSubscriptions(t *testing.T) {
var subIDs = make([]string, 0) var subIDs = make([]string, 0)
chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t) chain, rpcSrv, c, respMsgs, finishedFlag := initCleanServerAndWSClient(t)
@ -161,6 +251,12 @@ func TestBadSubUnsub(t *testing.T) {
"bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`, "bad (non-string) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": [1], "id": 1}`,
"bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`, "bad (wrong) event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_removed"], "id": 1}`,
"missed event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["event_missed"], "id": 1}`, "missed event": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["event_missed"], "id": 1}`,
"block invalid filter": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["block_added", 1], "id": 1}`,
"tx filter 1": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", 1], "id": 1}`,
"tx filter 2": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_added", {"state": "HALT"}], "id": 1}`,
"notification filter": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["notification_from_execution", "contract"], "id": 1}`,
"execution filter 1": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", "FAULT"], "id": 1}`,
"execution filter 2": `{"jsonrpc": "2.0", "method": "subscribe", "params": ["transaction_executed", {"state": "STOP"}], "id": 1}`,
} }
var unsubCases = map[string]string{ var unsubCases = map[string]string{
"no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`, "no params": `{"jsonrpc": "2.0", "method": "unsubscribe", "params": [], "id": 1}`,