oracle: check response Content-Type
If not specified everything is allowed. Signed-off-by: Evgeniy Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
parent
1853d0c713
commit
8e9302f40b
6 changed files with 90 additions and 13 deletions
|
@ -6,6 +6,7 @@ import "time"
|
|||
type OracleConfiguration struct {
|
||||
Enabled bool `yaml:"Enabled"`
|
||||
AllowPrivateHost bool `yaml:"AllowPrivateHost"`
|
||||
AllowedContentTypes []string `yaml:"AllowedContentTypes"`
|
||||
Nodes []string `yaml:"Nodes"`
|
||||
NeoFS NeoFSConfiguration `yaml:"NeoFS"`
|
||||
MaxTaskTimeout time.Duration `yaml:"MaxTaskTimeout"`
|
||||
|
|
|
@ -34,7 +34,8 @@ func getOracleConfig(t *testing.T, bc *Blockchain, w, pass string) oracle.Config
|
|||
Log: zaptest.NewLogger(t),
|
||||
Network: netmode.UnitTestNet,
|
||||
MainCfg: config.OracleConfiguration{
|
||||
RefreshInterval: time.Second,
|
||||
RefreshInterval: time.Second,
|
||||
AllowedContentTypes: []string{"application/json"},
|
||||
UnlockWallet: config.Wallet{
|
||||
Path: path.Join(oracleModulePath, w),
|
||||
Password: pass,
|
||||
|
@ -147,6 +148,8 @@ func TestOracle(t *testing.T) {
|
|||
putOracleRequest(t, cs.Hash, bc, "https://get.filter", &flt, "handle", []byte{}, 10_000_000)
|
||||
putOracleRequest(t, cs.Hash, bc, "https://get.filterinv", &flt, "handle", []byte{}, 10_000_000)
|
||||
|
||||
putOracleRequest(t, cs.Hash, bc, "https://get.invalidcontent", nil, "handle", []byte{}, 10_000_000)
|
||||
|
||||
checkResp := func(t *testing.T, id uint64, resp *transaction.OracleResponse) *state.OracleRequest {
|
||||
req, err := oracleCtr.GetRequestInternal(bc.dao, id)
|
||||
require.NoError(t, err)
|
||||
|
@ -262,6 +265,12 @@ func TestOracle(t *testing.T) {
|
|||
})
|
||||
})
|
||||
})
|
||||
t.Run("InvalidContentType", func(t *testing.T) {
|
||||
checkResp(t, 11, &transaction.OracleResponse{
|
||||
ID: 11,
|
||||
Code: transaction.ContentTypeNotSupported,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOracleFull(t *testing.T) {
|
||||
|
@ -322,6 +331,7 @@ type (
|
|||
|
||||
testResponse struct {
|
||||
code int
|
||||
ct string
|
||||
body []byte
|
||||
}
|
||||
)
|
||||
|
@ -332,7 +342,10 @@ func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
|
|||
if ok {
|
||||
return &http.Response{
|
||||
StatusCode: resp.code,
|
||||
Body: newResponseBody(resp.body),
|
||||
Header: http.Header{
|
||||
"Content-Type": {resp.ct},
|
||||
},
|
||||
Body: newResponseBody(resp.body),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("error during request")
|
||||
|
@ -343,44 +356,59 @@ func newDefaultHTTPClient() oracle.HTTPClient {
|
|||
responses: map[string]testResponse{
|
||||
"https://get.1234": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: []byte{1, 2, 3, 4},
|
||||
},
|
||||
"https://get.4321": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: []byte{4, 3, 2, 1},
|
||||
},
|
||||
"https://get.timeout": {
|
||||
code: http.StatusRequestTimeout,
|
||||
ct: "application/json",
|
||||
body: []byte{},
|
||||
},
|
||||
"https://get.notfound": {
|
||||
code: http.StatusNotFound,
|
||||
ct: "application/json",
|
||||
body: []byte{},
|
||||
},
|
||||
"https://get.forbidden": {
|
||||
code: http.StatusForbidden,
|
||||
ct: "application/json",
|
||||
body: []byte{},
|
||||
},
|
||||
"https://private.url": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: []byte("passwords"),
|
||||
},
|
||||
"https://get.big": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: make([]byte, transaction.MaxOracleResultSize+1),
|
||||
},
|
||||
"https://get.maxallowed": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: make([]byte, transaction.MaxOracleResultSize),
|
||||
},
|
||||
"https://get.filter": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: []byte(`{"Values":["one", 2, 3],"Another":null}`),
|
||||
},
|
||||
"https://get.filterinv": {
|
||||
code: http.StatusOK,
|
||||
ct: "application/json",
|
||||
body: []byte{0xFF},
|
||||
},
|
||||
"https://get.invalidcontent": {
|
||||
code: http.StatusOK,
|
||||
ct: "image/gif",
|
||||
body: []byte{1, 2, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,15 +26,16 @@ const MaxOracleResultSize = math.MaxUint16
|
|||
|
||||
// Enumeration of possible oracle response types.
|
||||
const (
|
||||
Success OracleResponseCode = 0x00
|
||||
ProtocolNotSupported OracleResponseCode = 0x10
|
||||
ConsensusUnreachable OracleResponseCode = 0x12
|
||||
NotFound OracleResponseCode = 0x14
|
||||
Timeout OracleResponseCode = 0x16
|
||||
Forbidden OracleResponseCode = 0x18
|
||||
ResponseTooLarge OracleResponseCode = 0x1a
|
||||
InsufficientFunds OracleResponseCode = 0x1c
|
||||
Error OracleResponseCode = 0xff
|
||||
Success OracleResponseCode = 0x00
|
||||
ProtocolNotSupported OracleResponseCode = 0x10
|
||||
ConsensusUnreachable OracleResponseCode = 0x12
|
||||
NotFound OracleResponseCode = 0x14
|
||||
Timeout OracleResponseCode = 0x16
|
||||
Forbidden OracleResponseCode = 0x18
|
||||
ResponseTooLarge OracleResponseCode = 0x1a
|
||||
InsufficientFunds OracleResponseCode = 0x1c
|
||||
ContentTypeNotSupported OracleResponseCode = 0x1f
|
||||
Error OracleResponseCode = 0xff
|
||||
)
|
||||
|
||||
// Various validation errors.
|
||||
|
|
|
@ -16,6 +16,7 @@ func _() {
|
|||
_ = x[Forbidden-24]
|
||||
_ = x[ResponseTooLarge-26]
|
||||
_ = x[InsufficientFunds-28]
|
||||
_ = x[ContentTypeNotSupported-31]
|
||||
_ = x[Error-255]
|
||||
}
|
||||
|
||||
|
@ -28,7 +29,8 @@ const (
|
|||
_OracleResponseCode_name_5 = "Forbidden"
|
||||
_OracleResponseCode_name_6 = "ResponseTooLarge"
|
||||
_OracleResponseCode_name_7 = "InsufficientFunds"
|
||||
_OracleResponseCode_name_8 = "Error"
|
||||
_OracleResponseCode_name_8 = "ContentTypeNotSupported"
|
||||
_OracleResponseCode_name_9 = "Error"
|
||||
)
|
||||
|
||||
func (i OracleResponseCode) String() string {
|
||||
|
@ -49,8 +51,10 @@ func (i OracleResponseCode) String() string {
|
|||
return _OracleResponseCode_name_6
|
||||
case i == 28:
|
||||
return _OracleResponseCode_name_7
|
||||
case i == 255:
|
||||
case i == 31:
|
||||
return _OracleResponseCode_name_8
|
||||
case i == 255:
|
||||
return _OracleResponseCode_name_9
|
||||
default:
|
||||
return "OracleResponseCode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package oracle
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
@ -126,6 +127,11 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
|||
}
|
||||
switch r.StatusCode {
|
||||
case http.StatusOK:
|
||||
if !checkMediaType(r.Header.Get("Content-Type"), o.MainCfg.AllowedContentTypes) {
|
||||
resp.Code = transaction.ContentTypeNotSupported
|
||||
break
|
||||
}
|
||||
|
||||
result, err := readResponse(r.Body, transaction.MaxOracleResultSize)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrResponseTooLarge) {
|
||||
|
@ -242,3 +248,21 @@ func (o *Oracle) processFailedRequest(priv *keys.PrivateKey, req request) {
|
|||
o.getOnTransaction()(readyTx)
|
||||
}
|
||||
}
|
||||
|
||||
func checkMediaType(hdr string, allowed []string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
typ, _, err := mime.ParseMediaType(hdr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, ct := range allowed {
|
||||
if ct == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
19
pkg/services/oracle/request_test.go
Normal file
19
pkg/services/oracle/request_test.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package oracle
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCheckContentType(t *testing.T) {
|
||||
allowedTypes := []string{"application/json", "text/plain"}
|
||||
require.True(t, checkMediaType("application/json", allowedTypes))
|
||||
require.True(t, checkMediaType("application/json; param=value", allowedTypes))
|
||||
require.True(t, checkMediaType("text/plain; filename=file.txt", allowedTypes))
|
||||
|
||||
require.False(t, checkMediaType("image/gif", allowedTypes))
|
||||
require.True(t, checkMediaType("image/gif", nil))
|
||||
|
||||
require.False(t, checkMediaType("invalid format", allowedTypes))
|
||||
}
|
Loading…
Reference in a new issue