Merge pull request #1625 from nspcc-dev/fix_oracle

core: restrict allowed Oracle callbacks
This commit is contained in:
Roman Khimov 2020-12-17 13:21:58 +03:00 committed by GitHub
commit c13d6ecc55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 9 deletions

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math" "math"
"math/big" "math/big"
"strings"
"github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
@ -275,6 +276,9 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string
if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 { if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 {
return ErrBigArgument return ErrBigArgument
} }
if strings.HasPrefix(cb, "_") {
return errors.New("disallowed callback method (starts with '_')")
}
if !ic.VM.AddGas(gas.Int64()) { if !ic.VM.AddGas(gas.Int64()) {
return ErrNotEnoughGas return ErrNotEnoughGas

View file

@ -92,13 +92,13 @@ func getOracleContractState(h util.Uint160) *state.Contract {
} }
func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
url string, filter *string, userData []byte, gas int64) util.Uint256 { url string, filter *string, cb string, userData []byte, gas int64) util.Uint256 {
var filtItem interface{} var filtItem interface{}
if filter != nil { if filter != nil {
filtItem = *filter filtItem = *filter
} }
res, err := invokeContractMethod(bc, gas+50_000_000+5_000_000, h, "requestURL", res, err := invokeContractMethod(bc, gas+50_000_000+5_000_000, h, "requestURL",
url, filtItem, "handle", userData, gas) url, filtItem, cb, userData, gas)
require.NoError(t, err) require.NoError(t, err)
return res.Container return res.Container
} }
@ -114,7 +114,7 @@ func TestOracle_Request(t *testing.T) {
gasForResponse := int64(2000_1234) gasForResponse := int64(2000_1234)
var filter = "flt" var filter = "flt"
userData := []byte("custom info") userData := []byte("custom info")
txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, userData, gasForResponse) txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, "handle", userData, gasForResponse)
req, err := orc.GetRequestInternal(bc.dao, 1) req, err := orc.GetRequestInternal(bc.dao, 1)
require.NotNil(t, req) require.NotNil(t, req)
@ -191,7 +191,7 @@ func TestOracle_Request(t *testing.T) {
t.Run("ErrorOnFinish", func(t *testing.T) { t.Run("ErrorOnFinish", func(t *testing.T) {
const reqID = 2 const reqID = 2
putOracleRequest(t, cs.Hash, bc, "url", nil, []byte{1, 2}, gasForResponse) putOracleRequest(t, cs.Hash, bc, "url", nil, "handle", []byte{1, 2}, gasForResponse)
_, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2 _, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2
require.NoError(t, err) require.NoError(t, err)
@ -215,22 +215,25 @@ func TestOracle_Request(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
t.Run("BadRequest", func(t *testing.T) { t.Run("BadRequest", func(t *testing.T) {
var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, userData []byte, gas int64) { var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, cb string, userData []byte, gas int64) {
txHash := putOracleRequest(t, h, bc, url, filter, userData, gas) txHash := putOracleRequest(t, h, bc, url, filter, cb, userData, gas)
aer, err := bc.GetAppExecResults(txHash, trigger.Application) aer, err := bc.GetAppExecResults(txHash, trigger.Application)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(aer)) require.Equal(t, 1, len(aer))
require.Equal(t, vm.FaultState, aer[0].VMState) require.Equal(t, vm.FaultState, aer[0].VMState)
} }
t.Run("non-UTF8 url", func(t *testing.T) { t.Run("non-UTF8 url", func(t *testing.T) {
doBadRequest(t, cs.Hash, "\xff", nil, []byte{1, 2}, gasForResponse) doBadRequest(t, cs.Hash, "\xff", nil, "", []byte{1, 2}, gasForResponse)
}) })
t.Run("non-UTF8 filter", func(t *testing.T) { t.Run("non-UTF8 filter", func(t *testing.T) {
var f = "\xff" var f = "\xff"
doBadRequest(t, cs.Hash, "url", &f, []byte{1, 2}, gasForResponse) doBadRequest(t, cs.Hash, "url", &f, "", []byte{1, 2}, gasForResponse)
}) })
t.Run("not enough gas", func(t *testing.T) { t.Run("not enough gas", func(t *testing.T) {
doBadRequest(t, cs.Hash, "url", nil, nil, 1000) doBadRequest(t, cs.Hash, "url", nil, "", nil, 1000)
})
t.Run("disallowed callback", func(t *testing.T) {
doBadRequest(t, cs.Hash, "url", nil, "_deploy", nil, 1000_0000)
}) })
}) })
} }