oracle: make use of ReadCloser returned from NeoFS's getters
Close #3032. Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
This commit is contained in:
parent
4b2fc32462
commit
0d470edf21
3 changed files with 70 additions and 68 deletions
|
@ -1,6 +1,7 @@
|
||||||
package neofs
|
package neofs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -18,10 +19,6 @@ import (
|
||||||
oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
|
oid "github.com/nspcc-dev/neofs-sdk-go/object/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResultReader is a function that reads required amount of data and
|
|
||||||
// checks it.
|
|
||||||
type ResultReader func(io.Reader) ([]byte, error)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// URIScheme is the name of neofs URI scheme.
|
// URIScheme is the name of neofs URI scheme.
|
||||||
URIScheme = "neofs"
|
URIScheme = "neofs"
|
||||||
|
@ -47,7 +44,7 @@ var (
|
||||||
// Get returns a neofs object from the provided url.
|
// Get returns a neofs object from the provided url.
|
||||||
// URI scheme is "neofs:<Container-ID>/<Object-ID/<Command>/<Params>".
|
// URI scheme is "neofs:<Container-ID>/<Object-ID/<Command>/<Params>".
|
||||||
// If Command is not provided, full object is requested.
|
// If Command is not provided, full object is requested.
|
||||||
func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string, resReader ResultReader) ([]byte, error) {
|
func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string) (io.ReadCloser, error) {
|
||||||
objectAddr, ps, err := parseNeoFSURL(u)
|
objectAddr, ps, err := parseNeoFSURL(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -63,27 +60,44 @@ func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string, re
|
||||||
return nil, fmt.Errorf("failed to create client: %w", err)
|
return nil, fmt.Errorf("failed to create client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prmd client.PrmDial
|
var (
|
||||||
|
res = clientCloseWrapper{c: c}
|
||||||
|
prmd client.PrmDial
|
||||||
|
)
|
||||||
prmd.SetServerURI(addr)
|
prmd.SetServerURI(addr)
|
||||||
prmd.SetContext(ctx)
|
prmd.SetContext(ctx)
|
||||||
err = c.Dial(prmd) //nolint:contextcheck // contextcheck: Function `Dial->Balance->SendUnary->Init->setNeoFSAPIServer` should pass the context parameter
|
err = c.Dial(prmd) //nolint:contextcheck // contextcheck: Function `Dial->Balance->SendUnary->Init->setNeoFSAPIServer` should pass the context parameter
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return res, err
|
||||||
}
|
}
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case len(ps) == 0 || ps[0] == "": // Get request
|
case len(ps) == 0 || ps[0] == "": // Get request
|
||||||
return getPayload(ctx, c, objectAddr, resReader)
|
res.ReadCloser, err = getPayload(ctx, c, objectAddr)
|
||||||
case ps[0] == rangeCmd:
|
case ps[0] == rangeCmd:
|
||||||
return getRange(ctx, c, objectAddr, resReader, ps[1:]...)
|
res.ReadCloser, err = getRange(ctx, c, objectAddr, ps[1:]...)
|
||||||
case ps[0] == headerCmd:
|
case ps[0] == headerCmd:
|
||||||
return getHeader(ctx, c, objectAddr)
|
res.ReadCloser, err = getHeader(ctx, c, objectAddr)
|
||||||
case ps[0] == hashCmd:
|
case ps[0] == hashCmd:
|
||||||
return getHash(ctx, c, objectAddr, ps[1:]...)
|
res.ReadCloser, err = getHash(ctx, c, objectAddr, ps[1:]...)
|
||||||
default:
|
default:
|
||||||
return nil, ErrInvalidCommand
|
err = ErrInvalidCommand
|
||||||
}
|
}
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientCloseWrapper struct {
|
||||||
|
io.ReadCloser
|
||||||
|
c *client.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w clientCloseWrapper) Close() error {
|
||||||
|
var res error
|
||||||
|
if w.ReadCloser != nil {
|
||||||
|
res = w.ReadCloser.Close()
|
||||||
|
}
|
||||||
|
w.c.Close()
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseNeoFSURL returns parsed neofs address.
|
// parseNeoFSURL returns parsed neofs address.
|
||||||
|
@ -112,24 +126,11 @@ func parseNeoFSURL(u *url.URL) (*oid.Address, []string, error) {
|
||||||
return objAddr, ps[2:], nil
|
return objAddr, ps[2:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPayload(ctx context.Context, c *client.Client, addr *oid.Address, resReader ResultReader) ([]byte, error) {
|
func getPayload(ctx context.Context, c *client.Client, addr *oid.Address) (io.ReadCloser, error) {
|
||||||
objR, err := c.ObjectGetInit(ctx, addr.Container(), addr.Object(), client.PrmObjectGet{})
|
return c.ObjectGetInit(ctx, addr.Container(), addr.Object(), client.PrmObjectGet{})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp, err := resReader(objR)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = objR.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRange(ctx context.Context, c *client.Client, addr *oid.Address, resReader ResultReader, ps ...string) ([]byte, error) {
|
func getRange(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) (io.ReadCloser, error) {
|
||||||
if len(ps) == 0 {
|
if len(ps) == 0 {
|
||||||
return nil, ErrInvalidRange
|
return nil, ErrInvalidRange
|
||||||
}
|
}
|
||||||
|
@ -138,20 +139,7 @@ func getRange(ctx context.Context, c *client.Client, addr *oid.Address, resReade
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rangeR, err := c.ObjectRangeInit(ctx, addr.Container(), addr.Object(), r.GetOffset(), r.GetLength(), client.PrmObjectRange{})
|
return c.ObjectRangeInit(ctx, addr.Container(), addr.Object(), r.GetOffset(), r.GetLength(), client.PrmObjectRange{})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp, err := resReader(rangeR)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = rangeR.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getObjHeader(ctx context.Context, c *client.Client, addr *oid.Address) (*object.Object, error) {
|
func getObjHeader(ctx context.Context, c *client.Client, addr *oid.Address) (*object.Object, error) {
|
||||||
|
@ -166,15 +154,19 @@ func getObjHeader(ctx context.Context, c *client.Client, addr *oid.Address) (*ob
|
||||||
return obj, nil
|
return obj, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHeader(ctx context.Context, c *client.Client, addr *oid.Address) ([]byte, error) {
|
func getHeader(ctx context.Context, c *client.Client, addr *oid.Address) (io.ReadCloser, error) {
|
||||||
obj, err := getObjHeader(ctx, c, addr)
|
obj, err := getObjHeader(ctx, c, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return obj.MarshalHeaderJSON()
|
res, err := obj.MarshalHeaderJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(res)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) ([]byte, error) {
|
func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) (io.ReadCloser, error) {
|
||||||
if len(ps) == 0 || ps[0] == "" { // hash of the full payload
|
if len(ps) == 0 || ps[0] == "" { // hash of the full payload
|
||||||
obj, err := getObjHeader(ctx, c, addr)
|
obj, err := getObjHeader(ctx, c, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -184,7 +176,7 @@ func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...str
|
||||||
if !flag {
|
if !flag {
|
||||||
return nil, errors.New("missing checksum in the reply")
|
return nil, errors.New("missing checksum in the reply")
|
||||||
}
|
}
|
||||||
return sum.Value(), nil
|
return io.NopCloser(bytes.NewReader(sum.Value())), nil
|
||||||
}
|
}
|
||||||
r, err := parseRange(ps[0])
|
r, err := parseRange(ps[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -204,7 +196,11 @@ func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...str
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode Uint256: %w", err)
|
return nil, fmt.Errorf("decode Uint256: %w", err)
|
||||||
}
|
}
|
||||||
return u256.MarshalJSON()
|
res, err := u256.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(res)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRange(s string) (*object.Range, error) {
|
func parseRange(s string) (*object.Range, error) {
|
||||||
|
|
|
@ -146,16 +146,7 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Result, err = readResponse(r.Body)
|
resp.Result, resp.Code = o.readResponse(r.Body, req.Req.URL)
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, ErrResponseTooLarge) {
|
|
||||||
resp.Code = transaction.ResponseTooLarge
|
|
||||||
} else {
|
|
||||||
resp.Code = transaction.Error
|
|
||||||
}
|
|
||||||
o.Log.Warn("failed to read data for oracle request", zap.String("url", req.Req.URL), zap.Error(err))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case http.StatusForbidden:
|
case http.StatusForbidden:
|
||||||
resp.Code = transaction.Forbidden
|
resp.Code = transaction.Forbidden
|
||||||
case http.StatusNotFound:
|
case http.StatusNotFound:
|
||||||
|
@ -169,15 +160,17 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), o.MainCfg.NeoFS.Timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), o.MainCfg.NeoFS.Timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
index := (int(req.ID) + incTx.attempts) % len(o.MainCfg.NeoFS.Nodes)
|
index := (int(req.ID) + incTx.attempts) % len(o.MainCfg.NeoFS.Nodes)
|
||||||
resp.Result, err = neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index], readResponse)
|
rc, err := neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrResponseTooLarge) {
|
resp.Code = transaction.Error
|
||||||
resp.Code = transaction.ResponseTooLarge
|
o.Log.Warn("failed to perform oracle request", zap.String("url", req.Req.URL), zap.Error(err))
|
||||||
} else {
|
if rc != nil {
|
||||||
resp.Code = transaction.Error
|
rc.Close() // intentionally skip the closing error, make it unified with Oracle `https` protocol.
|
||||||
}
|
}
|
||||||
o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err))
|
break
|
||||||
}
|
}
|
||||||
|
resp.Result, resp.Code = o.readResponse(rc, req.Req.URL)
|
||||||
|
rc.Close() // intentionally skip the closing error, make it unified with Oracle `https` protocol.
|
||||||
default:
|
default:
|
||||||
resp.Code = transaction.ProtocolNotSupported
|
resp.Code = transaction.ProtocolNotSupported
|
||||||
o.Log.Warn("unknown oracle request scheme", zap.String("url", req.Req.URL))
|
o.Log.Warn("unknown oracle request scheme", zap.String("url", req.Req.URL))
|
||||||
|
|
|
@ -68,17 +68,30 @@ func (o *Oracle) AddResponse(pub *keys.PublicKey, reqID uint64, txSig []byte) {
|
||||||
// ErrResponseTooLarge is returned when a response exceeds the max allowed size.
|
// ErrResponseTooLarge is returned when a response exceeds the max allowed size.
|
||||||
var ErrResponseTooLarge = errors.New("too big response")
|
var ErrResponseTooLarge = errors.New("too big response")
|
||||||
|
|
||||||
func readResponse(rc gio.Reader) ([]byte, error) {
|
func (o *Oracle) readResponse(rc gio.Reader, url string) ([]byte, transaction.OracleResponseCode) {
|
||||||
const limit = transaction.MaxOracleResultSize
|
const limit = transaction.MaxOracleResultSize
|
||||||
buf := make([]byte, limit+1)
|
buf := make([]byte, limit+1)
|
||||||
n, err := gio.ReadFull(rc, buf)
|
n, err := gio.ReadFull(rc, buf)
|
||||||
if errors.Is(err, gio.ErrUnexpectedEOF) && n <= limit {
|
if errors.Is(err, gio.ErrUnexpectedEOF) && n <= limit {
|
||||||
return checkUTF8(buf[:n])
|
res, err := checkUTF8(buf[:n])
|
||||||
|
return o.handleResponseError(res, err, url)
|
||||||
}
|
}
|
||||||
if err == nil || n > limit {
|
if err == nil || n > limit {
|
||||||
return nil, ErrResponseTooLarge
|
return o.handleResponseError(nil, ErrResponseTooLarge, url)
|
||||||
}
|
}
|
||||||
return nil, err
|
|
||||||
|
return o.handleResponseError(nil, err, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Oracle) handleResponseError(data []byte, err error, url string) ([]byte, transaction.OracleResponseCode) {
|
||||||
|
if err != nil {
|
||||||
|
o.Log.Warn("failed to read data for oracle request", zap.String("url", url), zap.Error(err))
|
||||||
|
if errors.Is(err, ErrResponseTooLarge) {
|
||||||
|
return nil, transaction.ResponseTooLarge
|
||||||
|
}
|
||||||
|
return nil, transaction.Error
|
||||||
|
}
|
||||||
|
return data, transaction.Success
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUTF8(v []byte) ([]byte, error) {
|
func checkUTF8(v []byte) ([]byte, error) {
|
||||||
|
|
Loading…
Reference in a new issue