diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index 6ad40eeef..497373ebc 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -6,6 +6,7 @@ import ( "fmt" "math" "math/big" + "strings" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" @@ -275,6 +276,9 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 { return ErrBigArgument } + if strings.HasPrefix(cb, "_") { + return errors.New("disallowed callback method (starts with '_')") + } if !ic.VM.AddGas(gas.Int64()) { return ErrNotEnoughGas diff --git a/pkg/core/native_oracle_test.go b/pkg/core/native_oracle_test.go index 80555bb15..533f25d94 100644 --- a/pkg/core/native_oracle_test.go +++ b/pkg/core/native_oracle_test.go @@ -92,13 +92,13 @@ func getOracleContractState(h util.Uint160) *state.Contract { } func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, - url string, filter *string, userData []byte, gas int64) util.Uint256 { + url string, filter *string, cb string, userData []byte, gas int64) util.Uint256 { var filtItem interface{} if filter != nil { filtItem = *filter } res, err := invokeContractMethod(bc, gas+50_000_000+5_000_000, h, "requestURL", - url, filtItem, "handle", userData, gas) + url, filtItem, cb, userData, gas) require.NoError(t, err) return res.Container } @@ -114,7 +114,7 @@ func TestOracle_Request(t *testing.T) { gasForResponse := int64(2000_1234) var filter = "flt" userData := []byte("custom info") - txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, userData, gasForResponse) + txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, "handle", userData, gasForResponse) req, err := orc.GetRequestInternal(bc.dao, 1) require.NotNil(t, req) @@ -191,7 +191,7 @@ func TestOracle_Request(t *testing.T) { t.Run("ErrorOnFinish", func(t *testing.T) { const reqID = 2 - putOracleRequest(t, cs.Hash, bc, "url", nil, []byte{1, 2}, gasForResponse) + putOracleRequest(t, cs.Hash, bc, "url", nil, "handle", []byte{1, 2}, gasForResponse) _, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2 require.NoError(t, err) @@ -215,22 +215,25 @@ func TestOracle_Request(t *testing.T) { require.Error(t, err) }) t.Run("BadRequest", func(t *testing.T) { - var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, userData []byte, gas int64) { - txHash := putOracleRequest(t, h, bc, url, filter, userData, gas) + var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, cb string, userData []byte, gas int64) { + txHash := putOracleRequest(t, h, bc, url, filter, cb, userData, gas) aer, err := bc.GetAppExecResults(txHash, trigger.Application) require.NoError(t, err) require.Equal(t, 1, len(aer)) require.Equal(t, vm.FaultState, aer[0].VMState) } t.Run("non-UTF8 url", func(t *testing.T) { - doBadRequest(t, cs.Hash, "\xff", nil, []byte{1, 2}, gasForResponse) + doBadRequest(t, cs.Hash, "\xff", nil, "", []byte{1, 2}, gasForResponse) }) t.Run("non-UTF8 filter", func(t *testing.T) { var f = "\xff" - doBadRequest(t, cs.Hash, "url", &f, []byte{1, 2}, gasForResponse) + doBadRequest(t, cs.Hash, "url", &f, "", []byte{1, 2}, gasForResponse) }) t.Run("not enough gas", func(t *testing.T) { - doBadRequest(t, cs.Hash, "url", nil, nil, 1000) + doBadRequest(t, cs.Hash, "url", nil, "", nil, 1000) + }) + t.Run("disallowed callback", func(t *testing.T) { + doBadRequest(t, cs.Hash, "url", nil, "_deploy", nil, 1000_0000) }) }) }