Merge pull request #1883 from nspcc-dev/fix-oracle-request-processing

Fix oracle request processing
This commit is contained in:
Roman Khimov 2021-04-06 17:25:22 +03:00 committed by GitHub
commit fc42b77916
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 69 deletions

View file

@ -18,6 +18,6 @@ type OracleConfiguration struct {
// NeoFSConfiguration is a config for the NeoFS service. // NeoFSConfiguration is a config for the NeoFS service.
type NeoFSConfiguration struct { type NeoFSConfiguration struct {
Nodes []string `yaml:"Nodes"` Nodes []string `yaml:"Nodes"`
Timeout int `yaml:"Timeout"` Timeout time.Duration `yaml:"Timeout"`
} }

View file

@ -133,19 +133,19 @@ func TestOracle(t *testing.T) {
cs := getOracleContractState(bc.contracts.Oracle.Hash, bc.contracts.Std.Hash) cs := getOracleContractState(bc.contracts.Oracle.Hash, bc.contracts.Std.Hash)
require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs)) require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs))
putOracleRequest(t, cs.Hash, bc, "http://get.1234", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.1234", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.1234", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.1234", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.timeout", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.timeout", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.notfound", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.notfound", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.forbidden", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.forbidden", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://private.url", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://private.url", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.big", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.big", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.maxallowed", nil, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.maxallowed", nil, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.maxallowed", nil, "handle", []byte{}, 100_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.maxallowed", nil, "handle", []byte{}, 100_000_000)
flt := "Values[1]" flt := "Values[1]"
putOracleRequest(t, cs.Hash, bc, "http://get.filter", &flt, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.filter", &flt, "handle", []byte{}, 10_000_000)
putOracleRequest(t, cs.Hash, bc, "http://get.filterinv", &flt, "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.filterinv", &flt, "handle", []byte{}, 10_000_000)
checkResp := func(t *testing.T, id uint64, resp *transaction.OracleResponse) *state.OracleRequest { checkResp := func(t *testing.T, id uint64, resp *transaction.OracleResponse) *state.OracleRequest {
req, err := oracleCtr.GetRequestInternal(bc.dao, id) req, err := oracleCtr.GetRequestInternal(bc.dao, id)
@ -279,7 +279,7 @@ func TestOracleFull(t *testing.T) {
t.Cleanup(orc.Shutdown) t.Cleanup(orc.Shutdown)
bc.setNodesByRole(t, true, noderoles.Oracle, keys.PublicKeys{acc.PrivateKey().PublicKey()}) bc.setNodesByRole(t, true, noderoles.Oracle, keys.PublicKeys{acc.PrivateKey().PublicKey()})
putOracleRequest(t, cs.Hash, bc, "http://get.1234", new(string), "handle", []byte{}, 10_000_000) putOracleRequest(t, cs.Hash, bc, "https://get.1234", new(string), "handle", []byte{}, 10_000_000)
require.Eventually(t, func() bool { return mp.Count() == 1 }, require.Eventually(t, func() bool { return mp.Count() == 1 },
time.Second*3, time.Millisecond*200) time.Second*3, time.Millisecond*200)
@ -341,43 +341,43 @@ func (c *httpClient) Get(url string) (*http.Response, error) {
func newDefaultHTTPClient() oracle.HTTPClient { func newDefaultHTTPClient() oracle.HTTPClient {
return &httpClient{ return &httpClient{
responses: map[string]testResponse{ responses: map[string]testResponse{
"http://get.1234": { "https://get.1234": {
code: http.StatusOK, code: http.StatusOK,
body: []byte{1, 2, 3, 4}, body: []byte{1, 2, 3, 4},
}, },
"http://get.4321": { "https://get.4321": {
code: http.StatusOK, code: http.StatusOK,
body: []byte{4, 3, 2, 1}, body: []byte{4, 3, 2, 1},
}, },
"http://get.timeout": { "https://get.timeout": {
code: http.StatusRequestTimeout, code: http.StatusRequestTimeout,
body: []byte{}, body: []byte{},
}, },
"http://get.notfound": { "https://get.notfound": {
code: http.StatusNotFound, code: http.StatusNotFound,
body: []byte{}, body: []byte{},
}, },
"http://get.forbidden": { "https://get.forbidden": {
code: http.StatusForbidden, code: http.StatusForbidden,
body: []byte{}, body: []byte{},
}, },
"http://private.url": { "https://private.url": {
code: http.StatusOK, code: http.StatusOK,
body: []byte("passwords"), body: []byte("passwords"),
}, },
"http://get.big": { "https://get.big": {
code: http.StatusOK, code: http.StatusOK,
body: make([]byte, transaction.MaxOracleResultSize+1), body: make([]byte, transaction.MaxOracleResultSize+1),
}, },
"http://get.maxallowed": { "https://get.maxallowed": {
code: http.StatusOK, code: http.StatusOK,
body: make([]byte, transaction.MaxOracleResultSize), body: make([]byte, transaction.MaxOracleResultSize),
}, },
"http://get.filter": { "https://get.filter": {
code: http.StatusOK, code: http.StatusOK,
body: []byte(`{"Values":["one", 2, 3],"Another":null}`), body: []byte(`{"Values":["one", 2, 3],"Another":null}`),
}, },
"http://get.filterinv": { "https://get.filterinv": {
code: http.StatusOK, code: http.StatusOK,
body: []byte{0xFF}, body: []byte{0xFF},
}, },

View file

@ -47,7 +47,7 @@ var (
) )
// Get returns neofs object from the provided url. // Get returns neofs object from the provided url.
// URI scheme is "neofs://<Container-ID>/<Object-ID/<Command>/<Params>". // URI scheme is "neofs:<Container-ID>/<Object-ID/<Command>/<Params>".
// If Command is not provided, full object is requested. // If Command is not provided, full object is requested.
func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string) ([]byte, error) { func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string) ([]byte, error) {
objectAddr, ps, err := parseNeoFSURL(u) objectAddr, ps, err := parseNeoFSURL(u)
@ -80,25 +80,25 @@ func parseNeoFSURL(u *url.URL) (*object.Address, []string, error) {
return nil, nil, ErrInvalidScheme return nil, nil, ErrInvalidScheme
} }
ps := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") ps := strings.Split(u.Opaque, "/")
if len(ps) == 0 || ps[0] == "" { if len(ps) < 2 {
return nil, nil, ErrMissingObject return nil, nil, ErrMissingObject
} }
containerID := container.NewID() containerID := container.NewID()
if err := containerID.Parse(u.Hostname()); err != nil { if err := containerID.Parse(ps[0]); err != nil {
return nil, nil, fmt.Errorf("%w: %v", ErrInvalidContainer, err) return nil, nil, fmt.Errorf("%w: %v", ErrInvalidContainer, err)
} }
objectID := object.NewID() objectID := object.NewID()
if err := objectID.Parse(ps[0]); err != nil { if err := objectID.Parse(ps[1]); err != nil {
return nil, nil, fmt.Errorf("%w: %v", ErrInvalidObject, err) return nil, nil, fmt.Errorf("%w: %v", ErrInvalidObject, err)
} }
objectAddr := object.NewAddress() objectAddr := object.NewAddress()
objectAddr.SetContainerID(containerID) objectAddr.SetContainerID(containerID)
objectAddr.SetObjectID(objectID) objectAddr.SetObjectID(objectID)
return objectAddr, ps[1:], nil return objectAddr, ps[2:], nil
} }
func getPayload(ctx context.Context, c *client.Client, addr *object.Address) ([]byte, error) { func getPayload(ctx context.Context, c *client.Client, addr *object.Address) ([]byte, error) {

View file

@ -44,7 +44,7 @@ func TestParseNeoFSURL(t *testing.T) {
oid := object.NewID() oid := object.NewID()
require.NoError(t, oid.Parse(oStr)) require.NoError(t, oid.Parse(oStr))
validPrefix := "neofs://" + cStr + "/" + oStr validPrefix := "neofs:" + cStr + "/" + oStr
objectAddr := object.NewAddress() objectAddr := object.NewAddress()
objectAddr.SetContainerID(cid) objectAddr.SetContainerID(cid)
objectAddr.SetObjectID(oid) objectAddr.SetObjectID(oid)
@ -57,10 +57,10 @@ func TestParseNeoFSURL(t *testing.T) {
{validPrefix, nil, nil}, {validPrefix, nil, nil},
{validPrefix + "/", []string{""}, nil}, {validPrefix + "/", []string{""}, nil},
{validPrefix + "/range/1|2", []string{"range", "1|2"}, nil}, {validPrefix + "/range/1|2", []string{"range", "1|2"}, nil},
{"neoffs://" + cStr + "/" + oStr, nil, ErrInvalidScheme}, {"neoffs:" + cStr + "/" + oStr, nil, ErrInvalidScheme},
{"neofs://" + cStr, nil, ErrMissingObject}, {"neofs:" + cStr, nil, ErrMissingObject},
{"neofs://" + cStr + "ooo/" + oStr, nil, ErrInvalidContainer}, {"neofs:" + cStr + "ooo/" + oStr, nil, ErrInvalidContainer},
{"neofs://" + cStr + "/ooo" + oStr, nil, ErrInvalidObject}, {"neofs:" + cStr + "/ooo" + oStr, nil, ErrInvalidObject},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) { t.Run(tc.url, func(t *testing.T) {

View file

@ -108,6 +108,9 @@ func NewOracle(cfg Config) (*Oracle, error) {
if o.MainCfg.RequestTimeout == 0 { if o.MainCfg.RequestTimeout == 0 {
o.MainCfg.RequestTimeout = defaultRequestTimeout o.MainCfg.RequestTimeout = defaultRequestTimeout
} }
if o.MainCfg.NeoFS.Timeout == 0 {
o.MainCfg.NeoFS.Timeout = defaultRequestTimeout
}
if o.MainCfg.MaxConcurrentRequests == 0 { if o.MainCfg.MaxConcurrentRequests == 0 {
o.MainCfg.MaxConcurrentRequests = defaultMaxConcurrentRequests o.MainCfg.MaxConcurrentRequests = defaultMaxConcurrentRequests
} }

View file

@ -97,47 +97,65 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
} }
resp := &transaction.OracleResponse{ID: req.ID} resp := &transaction.OracleResponse{ID: req.ID}
u, err := url.ParseRequestURI(req.Req.URL) u, err := url.ParseRequestURI(req.Req.URL)
if err == nil && !o.MainCfg.AllowPrivateHost {
err = o.URIValidator(u)
}
if err != nil { if err != nil {
resp.Code = transaction.Forbidden o.Log.Warn("malformed oracle request", zap.String("url", req.Req.URL), zap.Error(err))
} else if u.Scheme == "http" { resp.Code = transaction.ProtocolNotSupported
r, err := o.Client.Get(req.Req.URL) } else {
switch { switch u.Scheme {
case err != nil: case "https":
resp.Code = transaction.Error if !o.MainCfg.AllowPrivateHost {
case r.StatusCode == http.StatusOK: err = o.URIValidator(u)
result, err := readResponse(r.Body, transaction.MaxOracleResultSize) if err != nil {
if err != nil { o.Log.Warn("forbidden oracle request", zap.String("url", req.Req.URL))
if errors.Is(err, ErrResponseTooLarge) { resp.Code = transaction.Forbidden
resp.Code = transaction.ResponseTooLarge break
} else {
resp.Code = transaction.Error
} }
}
r, err := o.Client.Get(req.Req.URL)
if err != nil {
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
resp.Code = transaction.Error
break break
} }
resp.Code, resp.Result = filterRequest(result, req.Req) switch r.StatusCode {
case r.StatusCode == http.StatusForbidden: case http.StatusOK:
resp.Code = transaction.Forbidden result, err := readResponse(r.Body, transaction.MaxOracleResultSize)
case r.StatusCode == http.StatusNotFound: if err != nil {
resp.Code = transaction.NotFound if errors.Is(err, ErrResponseTooLarge) {
case r.StatusCode == http.StatusRequestTimeout: resp.Code = transaction.ResponseTooLarge
resp.Code = transaction.Timeout } else {
resp.Code = transaction.Error
}
o.Log.Warn("failed to read data for oracle request", zap.String("url", req.Req.URL), zap.Error(err))
break
}
resp.Code, resp.Result = filterRequest(result, req.Req)
case http.StatusForbidden:
resp.Code = transaction.Forbidden
case http.StatusNotFound:
resp.Code = transaction.NotFound
case http.StatusRequestTimeout:
resp.Code = transaction.Timeout
default:
resp.Code = transaction.Error
}
case neofs.URIScheme:
ctx, cancel := context.WithTimeout(context.Background(), o.MainCfg.NeoFS.Timeout)
defer cancel()
index := (int(req.ID) + incTx.attempts) % len(o.MainCfg.NeoFS.Nodes)
res, err := neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index])
if err != nil {
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
resp.Code = transaction.Error
} else {
resp.Code, resp.Result = filterRequest(res, req.Req)
}
default: default:
resp.Code = transaction.Error resp.Code = transaction.ProtocolNotSupported
} o.Log.Warn("unknown oracle request scheme", zap.String("url", req.Req.URL))
} else if err == nil && u.Scheme == neofs.URIScheme {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(o.MainCfg.NeoFS.Timeout)*time.Millisecond)
defer cancel()
index := (int(req.ID) + incTx.attempts) % len(o.MainCfg.NeoFS.Nodes)
res, err := neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index])
if err != nil {
resp.Code = transaction.Error
} else {
resp.Code, resp.Result = filterRequest(res, req.Req)
} }
} }
o.Log.Debug("oracle request processed", zap.String("url", req.Req.URL), zap.Int("code", int(resp.Code)), zap.String("result", string(resp.Result)))
currentHeight := o.Chain.BlockHeight() currentHeight := o.Chain.BlockHeight()
_, h, err := o.Chain.GetTransaction(req.Req.OriginalTxID) _, h, err := o.Chain.GetTransaction(req.Req.OriginalTxID)