From 6d531a07a53dc19ad5ffbb4f90a550dfd8fe408f Mon Sep 17 00:00:00 2001 From: Leonard Lyubich <45413332+cthulhu-rider@users.noreply.github.com> Date: Wed, 23 Jun 2021 14:15:58 +0300 Subject: [PATCH] [#313] client/object: Always return number of bytes read from Get stream (#316) Fix failure to comply with a requirement of stdlib `io.Reader` docs: `When Read encounters an error or end-of-file condition after successfully reading n > 0 bytes, it returns the number of bytes read.` Prepare a platform for unit tests and test the affected case. Signed-off-by: Leonard Lyubich --- pkg/client/object.go | 28 +++++++----- pkg/client/object_test.go | 95 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 pkg/client/object_test.go diff --git a/pkg/client/object.go b/pkg/client/object.go index 2a962b8..62e895e 100644 --- a/pkg/client/object.go +++ b/pkg/client/object.go @@ -496,31 +496,35 @@ func (p *GetObjectParams) WithPayloadReaderHandler(f ReaderHandler) *GetObjectPa // wrapper over the Object Get stream that provides io.Reader. type objectPayloadReader struct { - stream *rpcapi.GetResponseReader + stream interface { + Read(*v2object.GetResponse) error + } resp v2object.GetResponse tail []byte } -func (x *objectPayloadReader) Read(p []byte) (int, error) { +func (x *objectPayloadReader) Read(p []byte) (read int, err error) { // read remaining tail - read := copy(p, x.tail) + read = copy(p, x.tail) x.tail = x.tail[read:] if len(p)-read == 0 { - return read, nil + return } // receive message from server stream - err := x.stream.Read(&x.resp) + err = x.stream.Read(&x.resp) if err != nil { if errors.Is(err, io.EOF) { - return 0, io.EOF + err = io.EOF + return } - return 0, fmt.Errorf("reading the response failed: %w", err) + err = fmt.Errorf("reading the response failed: %w", err) + return } // get chunk part message @@ -528,12 +532,14 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) { chunkPart, ok := part.(*v2object.GetObjectPartChunk) if !ok { - return 0, errWrongMessageSeq + err = errWrongMessageSeq + return } // verify response structure - if err := signature.VerifyServiceMessage(&x.resp); err != nil { - return 0, fmt.Errorf("response verification failed: %w", err) + if err = signature.VerifyServiceMessage(&x.resp); err != nil { + err = fmt.Errorf("response verification failed: %w", err) + return } // read new chunk @@ -546,7 +552,7 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) { // save the tail x.tail = append(x.tail, chunk[tailOffset:]...) - return read, nil + return } var errWrongMessageSeq = errors.New("incorrect message sequence") diff --git a/pkg/client/object_test.go b/pkg/client/object_test.go new file mode 100644 index 0000000..4c7d296 --- /dev/null +++ b/pkg/client/object_test.go @@ -0,0 +1,95 @@ +package client + +import ( + "io" + "testing" + + "github.com/nspcc-dev/neofs-api-go/v2/object" + "github.com/nspcc-dev/neofs-api-go/v2/signature" + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +type singleResponseStream struct { + called bool + resp object.GetResponse +} + +func (x *singleResponseStream) Read(r *object.GetResponse) error { + if x.called { + return io.EOF + } + + x.called = true + + *r = x.resp + + return nil +} + +var key = test.DecodeKey(0) + +func chunkResponse(c []byte) (r object.GetResponse) { + chunkPart := new(object.GetObjectPartChunk) + chunkPart.SetChunk(c) + + body := new(object.GetResponseBody) + body.SetObjectPart(chunkPart) + + r.SetBody(body) + + if err := signature.SignServiceMessage(key, &r); err != nil { + panic(err) + } + + return +} + +func data(sz int) []byte { + data := make([]byte, sz) + + for i := range data { + data[i] = byte(i) % ^byte(0) + } + + return data +} + +func checkFullRead(t *testing.T, r io.Reader, buf, payload []byte) { + var ( + restored []byte + read int + ) + + for { + n, err := r.Read(buf) + + read += n + restored = append(restored, buf[:n]...) + + if err != nil { + require.Equal(t, err, io.EOF) + break + + } + } + + require.Equal(t, payload, restored) + require.EqualValues(t, len(payload), read) +} + +func TestObjectPayloadReader_Read(t *testing.T) { + t.Run("read with tail", func(t *testing.T) { + payload := data(10) + + buf := make([]byte, len(payload)-1) + + var r io.Reader = &objectPayloadReader{ + stream: &singleResponseStream{ + resp: chunkResponse(payload), + }, + } + + checkFullRead(t, r, buf, payload) + }) +}