forked from TrueCloudLab/neoneo-go
Merge pull request #1625 from nspcc-dev/fix_oracle
core: restrict allowed Oracle callbacks
This commit is contained in:
commit
c13d6ecc55
2 changed files with 16 additions and 9 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
@ -275,6 +276,9 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string
|
||||||
if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 {
|
if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 {
|
||||||
return ErrBigArgument
|
return ErrBigArgument
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(cb, "_") {
|
||||||
|
return errors.New("disallowed callback method (starts with '_')")
|
||||||
|
}
|
||||||
|
|
||||||
if !ic.VM.AddGas(gas.Int64()) {
|
if !ic.VM.AddGas(gas.Int64()) {
|
||||||
return ErrNotEnoughGas
|
return ErrNotEnoughGas
|
||||||
|
|
|
@ -92,13 +92,13 @@ func getOracleContractState(h util.Uint160) *state.Contract {
|
||||||
}
|
}
|
||||||
|
|
||||||
func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
|
func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
|
||||||
url string, filter *string, userData []byte, gas int64) util.Uint256 {
|
url string, filter *string, cb string, userData []byte, gas int64) util.Uint256 {
|
||||||
var filtItem interface{}
|
var filtItem interface{}
|
||||||
if filter != nil {
|
if filter != nil {
|
||||||
filtItem = *filter
|
filtItem = *filter
|
||||||
}
|
}
|
||||||
res, err := invokeContractMethod(bc, gas+50_000_000+5_000_000, h, "requestURL",
|
res, err := invokeContractMethod(bc, gas+50_000_000+5_000_000, h, "requestURL",
|
||||||
url, filtItem, "handle", userData, gas)
|
url, filtItem, cb, userData, gas)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return res.Container
|
return res.Container
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func TestOracle_Request(t *testing.T) {
|
||||||
gasForResponse := int64(2000_1234)
|
gasForResponse := int64(2000_1234)
|
||||||
var filter = "flt"
|
var filter = "flt"
|
||||||
userData := []byte("custom info")
|
userData := []byte("custom info")
|
||||||
txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, userData, gasForResponse)
|
txHash := putOracleRequest(t, cs.Hash, bc, "url", &filter, "handle", userData, gasForResponse)
|
||||||
|
|
||||||
req, err := orc.GetRequestInternal(bc.dao, 1)
|
req, err := orc.GetRequestInternal(bc.dao, 1)
|
||||||
require.NotNil(t, req)
|
require.NotNil(t, req)
|
||||||
|
@ -191,7 +191,7 @@ func TestOracle_Request(t *testing.T) {
|
||||||
t.Run("ErrorOnFinish", func(t *testing.T) {
|
t.Run("ErrorOnFinish", func(t *testing.T) {
|
||||||
const reqID = 2
|
const reqID = 2
|
||||||
|
|
||||||
putOracleRequest(t, cs.Hash, bc, "url", nil, []byte{1, 2}, gasForResponse)
|
putOracleRequest(t, cs.Hash, bc, "url", nil, "handle", []byte{1, 2}, gasForResponse)
|
||||||
_, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2
|
_, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -215,22 +215,25 @@ func TestOracle_Request(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
t.Run("BadRequest", func(t *testing.T) {
|
t.Run("BadRequest", func(t *testing.T) {
|
||||||
var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, userData []byte, gas int64) {
|
var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, cb string, userData []byte, gas int64) {
|
||||||
txHash := putOracleRequest(t, h, bc, url, filter, userData, gas)
|
txHash := putOracleRequest(t, h, bc, url, filter, cb, userData, gas)
|
||||||
aer, err := bc.GetAppExecResults(txHash, trigger.Application)
|
aer, err := bc.GetAppExecResults(txHash, trigger.Application)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, len(aer))
|
require.Equal(t, 1, len(aer))
|
||||||
require.Equal(t, vm.FaultState, aer[0].VMState)
|
require.Equal(t, vm.FaultState, aer[0].VMState)
|
||||||
}
|
}
|
||||||
t.Run("non-UTF8 url", func(t *testing.T) {
|
t.Run("non-UTF8 url", func(t *testing.T) {
|
||||||
doBadRequest(t, cs.Hash, "\xff", nil, []byte{1, 2}, gasForResponse)
|
doBadRequest(t, cs.Hash, "\xff", nil, "", []byte{1, 2}, gasForResponse)
|
||||||
})
|
})
|
||||||
t.Run("non-UTF8 filter", func(t *testing.T) {
|
t.Run("non-UTF8 filter", func(t *testing.T) {
|
||||||
var f = "\xff"
|
var f = "\xff"
|
||||||
doBadRequest(t, cs.Hash, "url", &f, []byte{1, 2}, gasForResponse)
|
doBadRequest(t, cs.Hash, "url", &f, "", []byte{1, 2}, gasForResponse)
|
||||||
})
|
})
|
||||||
t.Run("not enough gas", func(t *testing.T) {
|
t.Run("not enough gas", func(t *testing.T) {
|
||||||
doBadRequest(t, cs.Hash, "url", nil, nil, 1000)
|
doBadRequest(t, cs.Hash, "url", nil, "", nil, 1000)
|
||||||
|
})
|
||||||
|
t.Run("disallowed callback", func(t *testing.T) {
|
||||||
|
doBadRequest(t, cs.Hash, "url", nil, "_deploy", nil, 1000_0000)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue