diff --git a/pkg/config/oracle_config.go b/pkg/config/oracle_config.go index a12454f70..33c2e6136 100644 --- a/pkg/config/oracle_config.go +++ b/pkg/config/oracle_config.go @@ -6,6 +6,7 @@ import "time" type OracleConfiguration struct { Enabled bool `yaml:"Enabled"` AllowPrivateHost bool `yaml:"AllowPrivateHost"` + AllowedContentTypes []string `yaml:"AllowedContentTypes"` Nodes []string `yaml:"Nodes"` NeoFS NeoFSConfiguration `yaml:"NeoFS"` MaxTaskTimeout time.Duration `yaml:"MaxTaskTimeout"` diff --git a/pkg/core/oracle_test.go b/pkg/core/oracle_test.go index a9a263c6c..f1a2148d0 100644 --- a/pkg/core/oracle_test.go +++ b/pkg/core/oracle_test.go @@ -34,7 +34,8 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config Log: zaptest.NewLogger(t), Network: netmode.UnitTestNet, MainCfg: config.OracleConfiguration{ - RefreshInterval: time.Second, + RefreshInterval: time.Second, + AllowedContentTypes: []string{"application/json"}, UnlockWallet: config.Wallet{ Path: path.Join(oracleModulePath, w), Password: pass, @@ -147,6 +148,8 @@ func TestOracle(t *testing.T) { putOracleRequest(t, cs.Hash, bc, "https://get.filter", &flt, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.filterinv", &flt, "handle", []byte{}, 10_000_000) + putOracleRequest(t, cs.Hash, bc, "https://get.invalidcontent", nil, "handle", []byte{}, 10_000_000) + checkResp := func(t *testing.T, id uint64, resp *transaction.OracleResponse) *state.OracleRequest { req, err := oracleCtr.GetRequestInternal(bc.dao, id) require.NoError(t, err) @@ -262,6 +265,12 @@ func TestOracle(t *testing.T) { }) }) }) + t.Run("InvalidContentType", func(t *testing.T) { + checkResp(t, 11, &transaction.OracleResponse{ + ID: 11, + Code: transaction.ContentTypeNotSupported, + }) + }) } func TestOracleFull(t *testing.T) { @@ -322,6 +331,7 @@ type ( testResponse struct { code int + ct string body []byte } ) @@ -332,7 +342,10 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) { if ok { return &http.Response{ StatusCode: resp.code, - Body: newResponseBody(resp.body), + Header: http.Header{ + "Content-Type": {resp.ct}, + }, + Body: newResponseBody(resp.body), }, nil } return nil, errors.New("error during request") @@ -343,44 +356,59 @@ func newDefaultHTTPClient() oracle.HTTPClient { responses: map[string]testResponse{ "https://get.1234": { code: http.StatusOK, + ct: "application/json", body: []byte{1, 2, 3, 4}, }, "https://get.4321": { code: http.StatusOK, + ct: "application/json", body: []byte{4, 3, 2, 1}, }, "https://get.timeout": { code: http.StatusRequestTimeout, + ct: "application/json", body: []byte{}, }, "https://get.notfound": { code: http.StatusNotFound, + ct: "application/json", body: []byte{}, }, "https://get.forbidden": { code: http.StatusForbidden, + ct: "application/json", body: []byte{}, }, "https://private.url": { code: http.StatusOK, + ct: "application/json", body: []byte("passwords"), }, "https://get.big": { code: http.StatusOK, + ct: "application/json", body: make([]byte, transaction.MaxOracleResultSize+1), }, "https://get.maxallowed": { code: http.StatusOK, + ct: "application/json", body: make([]byte, transaction.MaxOracleResultSize), }, "https://get.filter": { code: http.StatusOK, + ct: "application/json", body: []byte(`{"Values":["one", 2, 3],"Another":null}`), }, "https://get.filterinv": { code: http.StatusOK, + ct: "application/json", body: []byte{0xFF}, }, + "https://get.invalidcontent": { + code: http.StatusOK, + ct: "image/gif", + body: []byte{1, 2, 3}, + }, }, } } diff --git a/pkg/core/transaction/oracle.go b/pkg/core/transaction/oracle.go index 3c89d644c..e9ee0598a 100644 --- a/pkg/core/transaction/oracle.go +++ b/pkg/core/transaction/oracle.go @@ -26,15 +26,16 @@ const MaxOracleResultSize = math.MaxUint16 // Enumeration of possible oracle response types. const ( - Success OracleResponseCode = 0x00 - ProtocolNotSupported OracleResponseCode = 0x10 - ConsensusUnreachable OracleResponseCode = 0x12 - NotFound OracleResponseCode = 0x14 - Timeout OracleResponseCode = 0x16 - Forbidden OracleResponseCode = 0x18 - ResponseTooLarge OracleResponseCode = 0x1a - InsufficientFunds OracleResponseCode = 0x1c - Error OracleResponseCode = 0xff + Success OracleResponseCode = 0x00 + ProtocolNotSupported OracleResponseCode = 0x10 + ConsensusUnreachable OracleResponseCode = 0x12 + NotFound OracleResponseCode = 0x14 + Timeout OracleResponseCode = 0x16 + Forbidden OracleResponseCode = 0x18 + ResponseTooLarge OracleResponseCode = 0x1a + InsufficientFunds OracleResponseCode = 0x1c + ContentTypeNotSupported OracleResponseCode = 0x1f + Error OracleResponseCode = 0xff ) // Various validation errors. diff --git a/pkg/core/transaction/oracleresponsecode_string.go b/pkg/core/transaction/oracleresponsecode_string.go index 53a57e45c..5ebd57e97 100644 --- a/pkg/core/transaction/oracleresponsecode_string.go +++ b/pkg/core/transaction/oracleresponsecode_string.go @@ -16,6 +16,7 @@ func _() { _ = x[Forbidden-24] _ = x[ResponseTooLarge-26] _ = x[InsufficientFunds-28] + _ = x[ContentTypeNotSupported-31] _ = x[Error-255] } @@ -28,7 +29,8 @@ const ( _OracleResponseCode_name_5 = "Forbidden" _OracleResponseCode_name_6 = "ResponseTooLarge" _OracleResponseCode_name_7 = "InsufficientFunds" - _OracleResponseCode_name_8 = "Error" + _OracleResponseCode_name_8 = "ContentTypeNotSupported" + _OracleResponseCode_name_9 = "Error" ) func (i OracleResponseCode) String() string { @@ -49,8 +51,10 @@ func (i OracleResponseCode) String() string { return _OracleResponseCode_name_6 case i == 28: return _OracleResponseCode_name_7 - case i == 255: + case i == 31: return _OracleResponseCode_name_8 + case i == 255: + return _OracleResponseCode_name_9 default: return "OracleResponseCode(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index 216ba60e8..6dee6334e 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -3,6 +3,7 @@ package oracle import ( "context" "errors" + "mime" "net/http" "net/url" "time" @@ -126,6 +127,11 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { } switch r.StatusCode { case http.StatusOK: + if !checkMediaType(r.Header.Get("Content-Type"), o.MainCfg.AllowedContentTypes) { + resp.Code = transaction.ContentTypeNotSupported + break + } + result, err := readResponse(r.Body, transaction.MaxOracleResultSize) if err != nil { if errors.Is(err, ErrResponseTooLarge) { @@ -242,3 +248,21 @@ func (o *Oracle) processFailedRequest(priv *keys.PrivateKey, req request) { o.getOnTransaction()(readyTx) } } + +func checkMediaType(hdr string, allowed []string) bool { + if len(allowed) == 0 { + return true + } + + typ, _, err := mime.ParseMediaType(hdr) + if err != nil { + return false + } + + for _, ct := range allowed { + if ct == typ { + return true + } + } + return false +} diff --git a/pkg/services/oracle/request_test.go b/pkg/services/oracle/request_test.go new file mode 100644 index 000000000..60067dd1d --- /dev/null +++ b/pkg/services/oracle/request_test.go @@ -0,0 +1,19 @@ +package oracle + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckContentType(t *testing.T) { + allowedTypes := []string{"application/json", "text/plain"} + require.True(t, checkMediaType("application/json", allowedTypes)) + require.True(t, checkMediaType("application/json; param=value", allowedTypes)) + require.True(t, checkMediaType("text/plain; filename=file.txt", allowedTypes)) + + require.False(t, checkMediaType("image/gif", allowedTypes)) + require.True(t, checkMediaType("image/gif", nil)) + + require.False(t, checkMediaType("invalid format", allowedTypes)) +}