diff --git a/pkg/services/oracle/neofs/neofs.go b/pkg/services/oracle/neofs/neofs.go index 044fb2904..d7a701534 100644 --- a/pkg/services/oracle/neofs/neofs.go +++ b/pkg/services/oracle/neofs/neofs.go @@ -1,6 +1,7 @@ package neofs import ( + "bytes" "context" "errors" "fmt" @@ -18,10 +19,6 @@ import ( oid "github.com/nspcc-dev/neofs-sdk-go/object/id" ) -// ResultReader is a function that reads required amount of data and -// checks it. -type ResultReader func(io.Reader) ([]byte, error) - const ( // URIScheme is the name of neofs URI scheme. URIScheme = "neofs" @@ -47,7 +44,7 @@ var ( // Get returns a neofs object from the provided url. // URI scheme is "neofs://". // If Command is not provided, full object is requested. -func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string, resReader ResultReader) ([]byte, error) { +func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string) (io.ReadCloser, error) { objectAddr, ps, err := parseNeoFSURL(u) if err != nil { return nil, err @@ -63,27 +60,44 @@ func Get(ctx context.Context, priv *keys.PrivateKey, u *url.URL, addr string, re return nil, fmt.Errorf("failed to create client: %w", err) } - var prmd client.PrmDial + var ( + res = clientCloseWrapper{c: c} + prmd client.PrmDial + ) prmd.SetServerURI(addr) prmd.SetContext(ctx) err = c.Dial(prmd) //nolint:contextcheck // contextcheck: Function `Dial->Balance->SendUnary->Init->setNeoFSAPIServer` should pass the context parameter if err != nil { - return nil, err + return res, err } - defer c.Close() switch { case len(ps) == 0 || ps[0] == "": // Get request - return getPayload(ctx, c, objectAddr, resReader) + res.ReadCloser, err = getPayload(ctx, c, objectAddr) case ps[0] == rangeCmd: - return getRange(ctx, c, objectAddr, resReader, ps[1:]...) + res.ReadCloser, err = getRange(ctx, c, objectAddr, ps[1:]...) case ps[0] == headerCmd: - return getHeader(ctx, c, objectAddr) + res.ReadCloser, err = getHeader(ctx, c, objectAddr) case ps[0] == hashCmd: - return getHash(ctx, c, objectAddr, ps[1:]...) + res.ReadCloser, err = getHash(ctx, c, objectAddr, ps[1:]...) default: - return nil, ErrInvalidCommand + err = ErrInvalidCommand } + return res, err +} + +type clientCloseWrapper struct { + io.ReadCloser + c *client.Client +} + +func (w clientCloseWrapper) Close() error { + var res error + if w.ReadCloser != nil { + res = w.ReadCloser.Close() + } + w.c.Close() + return res } // parseNeoFSURL returns parsed neofs address. @@ -112,24 +126,11 @@ func parseNeoFSURL(u *url.URL) (*oid.Address, []string, error) { return objAddr, ps[2:], nil } -func getPayload(ctx context.Context, c *client.Client, addr *oid.Address, resReader ResultReader) ([]byte, error) { - objR, err := c.ObjectGetInit(ctx, addr.Container(), addr.Object(), client.PrmObjectGet{}) - if err != nil { - return nil, err - } - resp, err := resReader(objR) - if err != nil { - return nil, err - } - err = objR.Close() - if err != nil { - return nil, err - } - - return resp, nil +func getPayload(ctx context.Context, c *client.Client, addr *oid.Address) (io.ReadCloser, error) { + return c.ObjectGetInit(ctx, addr.Container(), addr.Object(), client.PrmObjectGet{}) } -func getRange(ctx context.Context, c *client.Client, addr *oid.Address, resReader ResultReader, ps ...string) ([]byte, error) { +func getRange(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) (io.ReadCloser, error) { if len(ps) == 0 { return nil, ErrInvalidRange } @@ -138,20 +139,7 @@ func getRange(ctx context.Context, c *client.Client, addr *oid.Address, resReade return nil, err } - rangeR, err := c.ObjectRangeInit(ctx, addr.Container(), addr.Object(), r.GetOffset(), r.GetLength(), client.PrmObjectRange{}) - if err != nil { - return nil, err - } - resp, err := resReader(rangeR) - if err != nil { - return nil, err - } - err = rangeR.Close() - if err != nil { - return nil, err - } - - return resp, nil + return c.ObjectRangeInit(ctx, addr.Container(), addr.Object(), r.GetOffset(), r.GetLength(), client.PrmObjectRange{}) } func getObjHeader(ctx context.Context, c *client.Client, addr *oid.Address) (*object.Object, error) { @@ -166,15 +154,19 @@ func getObjHeader(ctx context.Context, c *client.Client, addr *oid.Address) (*ob return obj, nil } -func getHeader(ctx context.Context, c *client.Client, addr *oid.Address) ([]byte, error) { +func getHeader(ctx context.Context, c *client.Client, addr *oid.Address) (io.ReadCloser, error) { obj, err := getObjHeader(ctx, c, addr) if err != nil { return nil, err } - return obj.MarshalHeaderJSON() + res, err := obj.MarshalHeaderJSON() + if err != nil { + return nil, err + } + return io.NopCloser(bytes.NewReader(res)), nil } -func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) ([]byte, error) { +func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...string) (io.ReadCloser, error) { if len(ps) == 0 || ps[0] == "" { // hash of the full payload obj, err := getObjHeader(ctx, c, addr) if err != nil { @@ -184,7 +176,7 @@ func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...str if !flag { return nil, errors.New("missing checksum in the reply") } - return sum.Value(), nil + return io.NopCloser(bytes.NewReader(sum.Value())), nil } r, err := parseRange(ps[0]) if err != nil { @@ -204,7 +196,11 @@ func getHash(ctx context.Context, c *client.Client, addr *oid.Address, ps ...str if err != nil { return nil, fmt.Errorf("decode Uint256: %w", err) } - return u256.MarshalJSON() + res, err := u256.MarshalJSON() + if err != nil { + return nil, err + } + return io.NopCloser(bytes.NewReader(res)), nil } func parseRange(s string) (*object.Range, error) { diff --git a/pkg/services/oracle/request.go b/pkg/services/oracle/request.go index a8bd75490..6267f55d0 100644 --- a/pkg/services/oracle/request.go +++ b/pkg/services/oracle/request.go @@ -146,16 +146,7 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { break } - resp.Result, err = readResponse(r.Body) - if err != nil { - if errors.Is(err, ErrResponseTooLarge) { - resp.Code = transaction.ResponseTooLarge - } else { - resp.Code = transaction.Error - } - o.Log.Warn("failed to read data for oracle request", zap.String("url", req.Req.URL), zap.Error(err)) - break - } + resp.Result, resp.Code = o.readResponse(r.Body, req.Req.URL) case http.StatusForbidden: resp.Code = transaction.Forbidden case http.StatusNotFound: @@ -169,15 +160,17 @@ func (o *Oracle) processRequest(priv *keys.PrivateKey, req request) error { ctx, cancel := context.WithTimeout(context.Background(), o.MainCfg.NeoFS.Timeout) defer cancel() index := (int(req.ID) + incTx.attempts) % len(o.MainCfg.NeoFS.Nodes) - resp.Result, err = neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index], readResponse) + rc, err := neofs.Get(ctx, priv, u, o.MainCfg.NeoFS.Nodes[index]) if err != nil { - if errors.Is(err, ErrResponseTooLarge) { - resp.Code = transaction.ResponseTooLarge - } else { - resp.Code = transaction.Error + resp.Code = transaction.Error + o.Log.Warn("failed to perform oracle request", zap.String("url", req.Req.URL), zap.Error(err)) + if rc != nil { + rc.Close() // intentionally skip the closing error, make it unified with Oracle `https` protocol. } - o.Log.Warn("oracle request failed", zap.String("url", req.Req.URL), zap.Error(err)) + break } + resp.Result, resp.Code = o.readResponse(rc, req.Req.URL) + rc.Close() // intentionally skip the closing error, make it unified with Oracle `https` protocol. default: resp.Code = transaction.ProtocolNotSupported o.Log.Warn("unknown oracle request scheme", zap.String("url", req.Req.URL)) diff --git a/pkg/services/oracle/response.go b/pkg/services/oracle/response.go index 6b1563e84..06062941e 100644 --- a/pkg/services/oracle/response.go +++ b/pkg/services/oracle/response.go @@ -68,17 +68,30 @@ func (o *Oracle) AddResponse(pub *keys.PublicKey, reqID uint64, txSig []byte) { // ErrResponseTooLarge is returned when a response exceeds the max allowed size. var ErrResponseTooLarge = errors.New("too big response") -func readResponse(rc gio.Reader) ([]byte, error) { +func (o *Oracle) readResponse(rc gio.Reader, url string) ([]byte, transaction.OracleResponseCode) { const limit = transaction.MaxOracleResultSize buf := make([]byte, limit+1) n, err := gio.ReadFull(rc, buf) if errors.Is(err, gio.ErrUnexpectedEOF) && n <= limit { - return checkUTF8(buf[:n]) + res, err := checkUTF8(buf[:n]) + return o.handleResponseError(res, err, url) } if err == nil || n > limit { - return nil, ErrResponseTooLarge + return o.handleResponseError(nil, ErrResponseTooLarge, url) } - return nil, err + + return o.handleResponseError(nil, err, url) +} + +func (o *Oracle) handleResponseError(data []byte, err error, url string) ([]byte, transaction.OracleResponseCode) { + if err != nil { + o.Log.Warn("failed to read data for oracle request", zap.String("url", url), zap.Error(err)) + if errors.Is(err, ErrResponseTooLarge) { + return nil, transaction.ResponseTooLarge + } + return nil, transaction.Error + } + return data, transaction.Success } func checkUTF8(v []byte) ([]byte, error) {