Fix failure to comply with a requirement of stdlib `io.Reader` docs: `When Read encounters an error or end-of-file condition after successfully reading n > 0 bytes, it returns the number of bytes read.` Prepare a platform for unit tests and test the affected case. Signed-off-by: Leonard Lyubich <leonard@nspcc.ru>
This commit is contained in:
parent
616b4b71a1
commit
6d531a07a5
2 changed files with 112 additions and 11 deletions
|
@ -496,31 +496,35 @@ func (p *GetObjectParams) WithPayloadReaderHandler(f ReaderHandler) *GetObjectPa
|
|||
|
||||
// wrapper over the Object Get stream that provides io.Reader.
|
||||
type objectPayloadReader struct {
|
||||
stream *rpcapi.GetResponseReader
|
||||
stream interface {
|
||||
Read(*v2object.GetResponse) error
|
||||
}
|
||||
|
||||
resp v2object.GetResponse
|
||||
|
||||
tail []byte
|
||||
}
|
||||
|
||||
func (x *objectPayloadReader) Read(p []byte) (int, error) {
|
||||
func (x *objectPayloadReader) Read(p []byte) (read int, err error) {
|
||||
// read remaining tail
|
||||
read := copy(p, x.tail)
|
||||
read = copy(p, x.tail)
|
||||
|
||||
x.tail = x.tail[read:]
|
||||
|
||||
if len(p)-read == 0 {
|
||||
return read, nil
|
||||
return
|
||||
}
|
||||
|
||||
// receive message from server stream
|
||||
err := x.stream.Read(&x.resp)
|
||||
err = x.stream.Read(&x.resp)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return 0, io.EOF
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("reading the response failed: %w", err)
|
||||
err = fmt.Errorf("reading the response failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// get chunk part message
|
||||
|
@ -528,12 +532,14 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) {
|
|||
|
||||
chunkPart, ok := part.(*v2object.GetObjectPartChunk)
|
||||
if !ok {
|
||||
return 0, errWrongMessageSeq
|
||||
err = errWrongMessageSeq
|
||||
return
|
||||
}
|
||||
|
||||
// verify response structure
|
||||
if err := signature.VerifyServiceMessage(&x.resp); err != nil {
|
||||
return 0, fmt.Errorf("response verification failed: %w", err)
|
||||
if err = signature.VerifyServiceMessage(&x.resp); err != nil {
|
||||
err = fmt.Errorf("response verification failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// read new chunk
|
||||
|
@ -546,7 +552,7 @@ func (x *objectPayloadReader) Read(p []byte) (int, error) {
|
|||
// save the tail
|
||||
x.tail = append(x.tail, chunk[tailOffset:]...)
|
||||
|
||||
return read, nil
|
||||
return
|
||||
}
|
||||
|
||||
var errWrongMessageSeq = errors.New("incorrect message sequence")
|
||||
|
|
95
pkg/client/object_test.go
Normal file
95
pkg/client/object_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neofs-api-go/v2/object"
|
||||
"github.com/nspcc-dev/neofs-api-go/v2/signature"
|
||||
"github.com/nspcc-dev/neofs-crypto/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type singleResponseStream struct {
|
||||
called bool
|
||||
resp object.GetResponse
|
||||
}
|
||||
|
||||
func (x *singleResponseStream) Read(r *object.GetResponse) error {
|
||||
if x.called {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
x.called = true
|
||||
|
||||
*r = x.resp
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var key = test.DecodeKey(0)
|
||||
|
||||
func chunkResponse(c []byte) (r object.GetResponse) {
|
||||
chunkPart := new(object.GetObjectPartChunk)
|
||||
chunkPart.SetChunk(c)
|
||||
|
||||
body := new(object.GetResponseBody)
|
||||
body.SetObjectPart(chunkPart)
|
||||
|
||||
r.SetBody(body)
|
||||
|
||||
if err := signature.SignServiceMessage(key, &r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func data(sz int) []byte {
|
||||
data := make([]byte, sz)
|
||||
|
||||
for i := range data {
|
||||
data[i] = byte(i) % ^byte(0)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func checkFullRead(t *testing.T, r io.Reader, buf, payload []byte) {
|
||||
var (
|
||||
restored []byte
|
||||
read int
|
||||
)
|
||||
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
|
||||
read += n
|
||||
restored = append(restored, buf[:n]...)
|
||||
|
||||
if err != nil {
|
||||
require.Equal(t, err, io.EOF)
|
||||
break
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, payload, restored)
|
||||
require.EqualValues(t, len(payload), read)
|
||||
}
|
||||
|
||||
func TestObjectPayloadReader_Read(t *testing.T) {
|
||||
t.Run("read with tail", func(t *testing.T) {
|
||||
payload := data(10)
|
||||
|
||||
buf := make([]byte, len(payload)-1)
|
||||
|
||||
var r io.Reader = &objectPayloadReader{
|
||||
stream: &singleResponseStream{
|
||||
resp: chunkResponse(payload),
|
||||
},
|
||||
}
|
||||
|
||||
checkFullRead(t, r, buf, payload)
|
||||
})
|
||||
}
|
Loading…
Reference in a new issue