diff --git a/pkg/client/object.go b/pkg/client/object.go index 2a962b8c..62e895ea 100644 --- a/pkg/client/object.go +++ b/pkg/client/object.go @@ -496,31 +496,35 @@ func (p *GetObjectParams) WithPayloadReaderHandler(f ReaderHandler) *GetObjectPa // wrapper over the Object Get stream that provides io.Reader. type objectPayloadReader struct { - stream *rpcapi.GetResponseReader + stream interface { + Read(*v2object.GetResponse) error + } resp v2object.GetResponse tail []byte } -func (x *objectPayloadReader) Read(p []byte) (int, error) { +func (x *objectPayloadReader) Read(p []byte) (read int, err error) { // read remaining tail - read := copy(p, x.tail) + read = copy(p, x.tail) x.tail = x.tail[read:] if len(p)-read == 0 { - return read, nil + return } // receive message from server stream - err := x.stream.Read(&x.resp) + err = x.stream.Read(&x.resp) if err != nil { if errors.Is(err, io.EOF) { - return 0, io.EOF + err = io.EOF + return } - return 0, fmt.Errorf("reading the response failed: %w", err) + err = fmt.Errorf("reading the response failed: %w", err) + return } // get chunk part message @@ -528,12 +532,14 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) { chunkPart, ok := part.(*v2object.GetObjectPartChunk) if !ok { - return 0, errWrongMessageSeq + err = errWrongMessageSeq + return } // verify response structure - if err := signature.VerifyServiceMessage(&x.resp); err != nil { - return 0, fmt.Errorf("response verification failed: %w", err) + if err = signature.VerifyServiceMessage(&x.resp); err != nil { + err = fmt.Errorf("response verification failed: %w", err) + return } // read new chunk @@ -546,7 +552,7 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) { // save the tail x.tail = append(x.tail, chunk[tailOffset:]...) - return read, nil + return } var errWrongMessageSeq = errors.New("incorrect message sequence") diff --git a/pkg/client/object_test.go b/pkg/client/object_test.go new file mode 100644 index 00000000..4c7d296d --- /dev/null +++ b/pkg/client/object_test.go @@ -0,0 +1,95 @@ +package client + +import ( + "io" + "testing" + + "github.com/nspcc-dev/neofs-api-go/v2/object" + "github.com/nspcc-dev/neofs-api-go/v2/signature" + "github.com/nspcc-dev/neofs-crypto/test" + "github.com/stretchr/testify/require" +) + +type singleResponseStream struct { + called bool + resp object.GetResponse +} + +func (x *singleResponseStream) Read(r *object.GetResponse) error { + if x.called { + return io.EOF + } + + x.called = true + + *r = x.resp + + return nil +} + +var key = test.DecodeKey(0) + +func chunkResponse(c []byte) (r object.GetResponse) { + chunkPart := new(object.GetObjectPartChunk) + chunkPart.SetChunk(c) + + body := new(object.GetResponseBody) + body.SetObjectPart(chunkPart) + + r.SetBody(body) + + if err := signature.SignServiceMessage(key, &r); err != nil { + panic(err) + } + + return +} + +func data(sz int) []byte { + data := make([]byte, sz) + + for i := range data { + data[i] = byte(i) % ^byte(0) + } + + return data +} + +func checkFullRead(t *testing.T, r io.Reader, buf, payload []byte) { + var ( + restored []byte + read int + ) + + for { + n, err := r.Read(buf) + + read += n + restored = append(restored, buf[:n]...) + + if err != nil { + require.Equal(t, err, io.EOF) + break + + } + } + + require.Equal(t, payload, restored) + require.EqualValues(t, len(payload), read) +} + +func TestObjectPayloadReader_Read(t *testing.T) { + t.Run("read with tail", func(t *testing.T) { + payload := data(10) + + buf := make([]byte, len(payload)-1) + + var r io.Reader = &objectPayloadReader{ + stream: &singleResponseStream{ + resp: chunkResponse(payload), + }, + } + + checkFullRead(t, r, buf, payload) + }) +}