[#313] client/object: Always return number of bytes read from Get stream (#316)

Fix failure to comply with a requirement of stdlib `io.Reader` docs: `When
Read encounters an error or end-of-file condition after successfully reading
n > 0 bytes, it returns the number of bytes read.`

Prepare a platform for unit tests and test the affected case.

Signed-off-by: Leonard Lyubich <leonard@nspcc.ru>
remotes/KirillovDenis/feature/refactor-sig-rpc
Leonard Lyubich 2021-06-23 14:15:58 +03:00 committed by GitHub
parent 616b4b71a1
commit 6d531a07a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 11 deletions

View File

@ -496,31 +496,35 @@ func (p *GetObjectParams) WithPayloadReaderHandler(f ReaderHandler) *GetObjectPa
// wrapper over the Object Get stream that provides io.Reader.
type objectPayloadReader struct {
stream *rpcapi.GetResponseReader
stream interface {
Read(*v2object.GetResponse) error
}
resp v2object.GetResponse
tail []byte
}
func (x *objectPayloadReader) Read(p []byte) (int, error) {
func (x *objectPayloadReader) Read(p []byte) (read int, err error) {
// read remaining tail
read := copy(p, x.tail)
read = copy(p, x.tail)
x.tail = x.tail[read:]
if len(p)-read == 0 {
return read, nil
return
}
// receive message from server stream
err := x.stream.Read(&x.resp)
err = x.stream.Read(&x.resp)
if err != nil {
if errors.Is(err, io.EOF) {
return 0, io.EOF
err = io.EOF
return
}
return 0, fmt.Errorf("reading the response failed: %w", err)
err = fmt.Errorf("reading the response failed: %w", err)
return
}
// get chunk part message
@ -528,12 +532,14 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) {
chunkPart, ok := part.(*v2object.GetObjectPartChunk)
if !ok {
return 0, errWrongMessageSeq
err = errWrongMessageSeq
return
}
// verify response structure
if err := signature.VerifyServiceMessage(&x.resp); err != nil {
return 0, fmt.Errorf("response verification failed: %w", err)
if err = signature.VerifyServiceMessage(&x.resp); err != nil {
err = fmt.Errorf("response verification failed: %w", err)
return
}
// read new chunk
@ -546,7 +552,7 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) {
// save the tail
x.tail = append(x.tail, chunk[tailOffset:]...)
return read, nil
return
}
var errWrongMessageSeq = errors.New("incorrect message sequence")

View File

@ -0,0 +1,95 @@
package client
import (
"io"
"testing"
"github.com/nspcc-dev/neofs-api-go/v2/object"
"github.com/nspcc-dev/neofs-api-go/v2/signature"
"github.com/nspcc-dev/neofs-crypto/test"
"github.com/stretchr/testify/require"
)
type singleResponseStream struct {
called bool
resp object.GetResponse
}
func (x *singleResponseStream) Read(r *object.GetResponse) error {
if x.called {
return io.EOF
}
x.called = true
*r = x.resp
return nil
}
var key = test.DecodeKey(0)
func chunkResponse(c []byte) (r object.GetResponse) {
chunkPart := new(object.GetObjectPartChunk)
chunkPart.SetChunk(c)
body := new(object.GetResponseBody)
body.SetObjectPart(chunkPart)
r.SetBody(body)
if err := signature.SignServiceMessage(key, &r); err != nil {
panic(err)
}
return
}
func data(sz int) []byte {
data := make([]byte, sz)
for i := range data {
data[i] = byte(i) % ^byte(0)
}
return data
}
func checkFullRead(t *testing.T, r io.Reader, buf, payload []byte) {
var (
restored []byte
read int
)
for {
n, err := r.Read(buf)
read += n
restored = append(restored, buf[:n]...)
if err != nil {
require.Equal(t, err, io.EOF)
break
}
}
require.Equal(t, payload, restored)
require.EqualValues(t, len(payload), read)
}
func TestObjectPayloadReader_Read(t *testing.T) {
t.Run("read with tail", func(t *testing.T) {
payload := data(10)
buf := make([]byte, len(payload)-1)
var r io.Reader = &objectPayloadReader{
stream: &singleResponseStream{
resp: chunkResponse(payload),
},
}
checkFullRead(t, r, buf, payload)
})
}