diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 24b2dc4a8..3d8de1897 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -186,10 +187,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.BlockFilterT, param.Type) - filt, ok := param.Value.(request.BlockFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.BlockFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, 3, filt.Primary) }, }, @@ -201,10 +202,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.TxFilterT, param.Type) - filt, ok := param.Value.(request.TxFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.TxFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Sender) require.Nil(t, filt.Signer) }, @@ -217,10 +218,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.TxFilterT, param.Type) - filt, ok := param.Value.(request.TxFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.TxFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Nil(t, filt.Sender) require.Equal(t, util.Uint160{0, 42}, *filt.Signer) }, @@ -234,10 +235,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.TxFilterT, param.Type) - filt, ok := param.Value.(request.TxFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.TxFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Sender) require.Equal(t, util.Uint160{0, 42}, *filt.Signer) }, @@ -250,10 +251,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.NotificationFilterT, param.Type) - filt, ok := param.Value.(request.NotificationFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.NotificationFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Contract) require.Nil(t, filt.Name) }, @@ -266,10 +267,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.NotificationFilterT, param.Type) - filt, ok := param.Value.(request.NotificationFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.NotificationFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, "my_pretty_notification", *filt.Name) require.Nil(t, filt.Contract) }, @@ -283,10 +284,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.NotificationFilterT, param.Type) - filt, ok := param.Value.(request.NotificationFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.NotificationFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, util.Uint160{1, 2, 3, 4, 5}, *filt.Contract) require.Equal(t, "my_pretty_notification", *filt.Name) }, @@ -299,10 +300,10 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, func(t *testing.T, p *request.Params) { param := p.Value(1) - require.NotNil(t, param) - require.Equal(t, request.ExecutionFilterT, param.Type) - filt, ok := param.Value.(request.ExecutionFilter) - require.Equal(t, true, ok) + raw, ok := param.Value.(json.RawMessage) + require.True(t, ok) + filt := new(request.ExecutionFilter) + require.NoError(t, json.Unmarshal(raw, filt)) require.Equal(t, "FAULT", filt.State) }, }, diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index d779f5d00..03834d050 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -134,9 +134,20 @@ func (p *Param) GetArray() ([]Param, error) { return nil, errMissingParameter } a, ok := p.Value.([]Param) + if ok { + return a, nil + } + raw, ok := p.Value.(json.RawMessage) if !ok { return nil, errors.New("not an array") } + + a = []Param{} + err := json.Unmarshal(raw, &a) + if err != nil { + return nil, errors.New("not an array") + } + p.Value = a return a, nil } @@ -190,9 +201,18 @@ func (p *Param) GetFuncParam() (FuncParam, error) { return FuncParam{}, errMissingParameter } fp, ok := p.Value.(FuncParam) + if ok { + return fp, nil + } + raw, ok := p.Value.(json.RawMessage) if !ok { return FuncParam{}, errors.New("not a function parameter") } + err := json.Unmarshal(raw, &fp) + if err != nil { + return fp, err + } + p.Value = fp return fp, nil } @@ -219,11 +239,38 @@ func (p *Param) GetBytesBase64() ([]byte, error) { } // GetSignerWithWitness returns SignerWithWitness value of the parameter. -func (p Param) GetSignerWithWitness() (SignerWithWitness, error) { +func (p *Param) GetSignerWithWitness() (SignerWithWitness, error) { c, ok := p.Value.(SignerWithWitness) + if ok { + return c, nil + } + raw, ok := p.Value.(json.RawMessage) if !ok { return SignerWithWitness{}, errors.New("not a signer") } + aux := new(signerWithWitnessAux) + err := json.Unmarshal(raw, aux) + if err != nil { + return SignerWithWitness{}, errors.New("not a signer") + } + accParam := Param{StringT, aux.Account} + acc, err := accParam.GetUint160FromAddressOrHex() + if err != nil { + return SignerWithWitness{}, errors.New("not a signer") + } + c = SignerWithWitness{ + Signer: transaction.Signer{ + Account: acc, + Scopes: aux.Scopes, + AllowedContracts: aux.AllowedContracts, + AllowedGroups: aux.AllowedGroups, + }, + Witness: transaction.Witness{ + InvocationScript: aux.InvocationScript, + VerificationScript: aux.VerificationScript, + }, + } + p.Value = c return c, nil } @@ -264,83 +311,46 @@ func (p Param) GetSignersWithWitnesses() ([]transaction.Signer, []transaction.Wi // UnmarshalJSON implements json.Unmarshaler interface. func (p *Param) UnmarshalJSON(data []byte) error { - var s string - var num float64 - var b bool - // To unmarshal correctly we need to pass pointers into the decoder. - var attempts = [...]Param{ - {NumberT, &num}, - {BooleanT, &b}, - {StringT, &s}, - {FuncParamT, &FuncParam{}}, - {BlockFilterT, &BlockFilter{}}, - {TxFilterT, &TxFilter{}}, - {NotificationFilterT, &NotificationFilter{}}, - {ExecutionFilterT, &ExecutionFilter{}}, - {SignerWithWitnessT, &signerWithWitnessAux{}}, - {ArrayT, &[]Param{}}, + r := bytes.NewReader(data) + jd := json.NewDecoder(r) + jd.UseNumber() + tok, err := jd.Token() + if err != nil { + return err } - - if bytes.Equal(data, []byte("null")) { - p.Type = defaultT - return nil - } - - for _, cur := range attempts { - r := bytes.NewReader(data) - jd := json.NewDecoder(r) - jd.DisallowUnknownFields() - if err := jd.Decode(cur.Value); err == nil { - p.Type = cur.Type - // But we need to store actual values, not pointers. - switch val := cur.Value.(type) { - case *float64: - p.Value = int(*val) - case *string: - p.Value = *val - case *bool: - p.Value = *val - case *FuncParam: - p.Value = *val - case *BlockFilter: - p.Value = *val - case *TxFilter: - p.Value = *val - case *NotificationFilter: - p.Value = *val - case *ExecutionFilter: - if (*val).State == "HALT" || (*val).State == "FAULT" { - p.Value = *val - } else { - continue - } - case *signerWithWitnessAux: - aux := *val - accParam := Param{StringT, aux.Account} - acc, err := accParam.GetUint160FromAddressOrHex() - if err != nil { - return err - } - p.Value = SignerWithWitness{ - Signer: transaction.Signer{ - Account: acc, - Scopes: aux.Scopes, - AllowedContracts: aux.AllowedContracts, - AllowedGroups: aux.AllowedGroups, - }, - Witness: transaction.Witness{ - InvocationScript: aux.InvocationScript, - VerificationScript: aux.VerificationScript, - }, - } - case *[]Param: - p.Value = *val + switch t := tok.(type) { + case json.Delim: + if t == json.Delim('[') { + var arr []Param + err := json.Unmarshal(data, &arr) + if err != nil { + return err } - return nil + p.Type = ArrayT + p.Value = arr + } else { + p.Type = defaultT + p.Value = json.RawMessage(data) } + case bool: + p.Type = BooleanT + p.Value = t + case float64: // unexpected because of `UseNumber`. + panic("unexpected") + case json.Number: + value, err := strconv.Atoi(string(t)) + if err != nil { + return err + } + p.Type = NumberT + p.Value = value + case string: + p.Type = StringT + p.Value = t + default: // null + p.Type = defaultT } - - return errors.New("unknown type") + return nil } // signerWithWitnessAux is an auxiluary struct for JSON marshalling. We need it because of diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index 61c3f8aa0..7303006c9 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -16,24 +16,9 @@ import ( func TestParam_UnmarshalJSON(t *testing.T) { msg := `["str1", 123, null, ["str2", 3], [{"type": "String", "value": "jajaja"}], - {"primary": 1}, - {"sender": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, - {"signer": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, - {"sender": "f84d6a337fbc3d3a201d41da99e86b479e7a2554", "signer": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, - {"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554"}, - {"name": "my_pretty_notification"}, - {"contract": "f84d6a337fbc3d3a201d41da99e86b479e7a2554", "name":"my_pretty_notification"}, - {"state": "HALT"}, {"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569"}, {"account": "NYxb4fSZVKAz8YsgaPK2WkT3KcAE9b3Vag", "scopes": "Global"}, [{"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569", "scopes": "Global"}]]` - contr, err := util.Uint160DecodeStringLE("f84d6a337fbc3d3a201d41da99e86b479e7a2554") - require.NoError(t, err) - name := "my_pretty_notification" - accountHash, err := util.Uint160DecodeStringLE("cadb3dc2faa3ef14a13b619c9a43124755aa2569") - require.NoError(t, err) - addrHash, err := address.StringToUint160("NYxb4fSZVKAz8YsgaPK2WkT3KcAE9b3Vag") - require.NoError(t, err) expected := Params{ { Type: StringT, @@ -63,78 +48,25 @@ func TestParam_UnmarshalJSON(t *testing.T) { Type: ArrayT, Value: []Param{ { - Type: FuncParamT, - Value: FuncParam{ - Type: smartcontract.StringType, - Value: Param{ - Type: StringT, - Value: "jajaja", - }, - }, + Type: defaultT, + Value: json.RawMessage(`{"type": "String", "value": "jajaja"}`), }, }, }, { - Type: BlockFilterT, - Value: BlockFilter{Primary: 1}, + Type: defaultT, + Value: json.RawMessage(`{"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569"}`), }, { - Type: TxFilterT, - Value: TxFilter{Sender: &contr}, - }, - { - Type: TxFilterT, - Value: TxFilter{Signer: &contr}, - }, - { - Type: TxFilterT, - Value: TxFilter{Sender: &contr, Signer: &contr}, - }, - { - Type: NotificationFilterT, - Value: NotificationFilter{Contract: &contr}, - }, - { - Type: NotificationFilterT, - Value: NotificationFilter{Name: &name}, - }, - { - Type: NotificationFilterT, - Value: NotificationFilter{Contract: &contr, Name: &name}, - }, - { - Type: ExecutionFilterT, - Value: ExecutionFilter{State: "HALT"}, - }, - { - Type: SignerWithWitnessT, - Value: SignerWithWitness{ - Signer: transaction.Signer{ - Account: accountHash, - Scopes: transaction.None, - }, - }, - }, - { - Type: SignerWithWitnessT, - Value: SignerWithWitness{ - Signer: transaction.Signer{ - Account: addrHash, - Scopes: transaction.Global, - }, - }, + Type: defaultT, + Value: json.RawMessage(`{"account": "NYxb4fSZVKAz8YsgaPK2WkT3KcAE9b3Vag", "scopes": "Global"}`), }, { Type: ArrayT, Value: []Param{ { - Type: SignerWithWitnessT, - Value: SignerWithWitness{ - Signer: transaction.Signer{ - Account: accountHash, - Scopes: transaction.Global, - }, - }, + Type: defaultT, + Value: json.RawMessage(`{"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569", "scopes": "Global"}`), }, }, }, @@ -143,11 +75,49 @@ func TestParam_UnmarshalJSON(t *testing.T) { var ps Params require.NoError(t, json.Unmarshal([]byte(msg), &ps)) require.Equal(t, expected, ps) +} - msg = `[{"2": 3}]` - require.Error(t, json.Unmarshal([]byte(msg), &ps)) - msg = `[{"account": "notanaccount", "scopes": "Global"}]` - require.Error(t, json.Unmarshal([]byte(msg), &ps)) +func TestGetWitness(t *testing.T) { + accountHash, err := util.Uint160DecodeStringLE("cadb3dc2faa3ef14a13b619c9a43124755aa2569") + require.NoError(t, err) + addrHash, err := address.StringToUint160("NYxb4fSZVKAz8YsgaPK2WkT3KcAE9b3Vag") + require.NoError(t, err) + + testCases := []struct { + raw string + expected SignerWithWitness + }{ + {`{"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569"}`, SignerWithWitness{ + Signer: transaction.Signer{ + Account: accountHash, + Scopes: transaction.None, + }}, + }, + {`{"account": "NYxb4fSZVKAz8YsgaPK2WkT3KcAE9b3Vag", "scopes": "Global"}`, SignerWithWitness{ + Signer: transaction.Signer{ + Account: addrHash, + Scopes: transaction.Global, + }}, + }, + {`{"account": "0xcadb3dc2faa3ef14a13b619c9a43124755aa2569", "scopes": "Global"}`, SignerWithWitness{ + Signer: transaction.Signer{ + Account: accountHash, + Scopes: transaction.Global, + }}, + }, + } + + for _, tc := range testCases { + p := Param{Value: json.RawMessage(tc.raw)} + actual, err := p.GetSignerWithWitness() + require.NoError(t, err) + require.Equal(t, tc.expected, actual) + require.Equal(t, tc.expected, p.Value) + + actual, err = p.GetSignerWithWitness() // valid second invocation. + require.NoError(t, err) + require.Equal(t, tc.expected, actual) + } } func TestParamGetString(t *testing.T) { diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 76f967d8b..3fa82c659 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "crypto/elliptic" "encoding/binary" @@ -1496,28 +1497,41 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface // Optional filter. var filter interface{} if p := reqParams.Value(1); p != nil { + param, ok := p.Value.(json.RawMessage) + if !ok { + return nil, response.ErrInvalidParams + } + jd := json.NewDecoder(bytes.NewReader(param)) + jd.DisallowUnknownFields() switch event { case response.BlockEventID: - if p.Type != request.BlockFilterT { - return nil, response.ErrInvalidParams - } - case response.TransactionEventID: - if p.Type != request.TxFilterT { - return nil, response.ErrInvalidParams - } + flt := new(request.BlockFilter) + err = jd.Decode(flt) + p.Type = request.BlockFilterT + p.Value = *flt + case response.TransactionEventID, response.NotaryRequestEventID: + flt := new(request.TxFilter) + err = jd.Decode(flt) + p.Type = request.TxFilterT + p.Value = *flt case response.NotificationEventID: - if p.Type != request.NotificationFilterT { - return nil, response.ErrInvalidParams - } + flt := new(request.NotificationFilter) + err = jd.Decode(flt) + p.Type = request.NotificationFilterT + p.Value = *flt case response.ExecutionEventID: - if p.Type != request.ExecutionFilterT { - return nil, response.ErrInvalidParams - } - case response.NotaryRequestEventID: - if p.Type != request.TxFilterT { - return nil, response.ErrInvalidParams + flt := new(request.ExecutionFilter) + err = jd.Decode(flt) + if err == nil && (flt.State == "HALT" || flt.State == "FAULT") { + p.Type = request.ExecutionFilterT + p.Value = *flt + } else if err == nil { + err = errors.New("invalid state") } } + if err != nil { + return nil, response.ErrInvalidParams + } filter = p.Value }