diff --git a/pkg/client/object.go b/pkg/client/object.go index 5f178af..2a962b8 100644 --- a/pkg/client/object.go +++ b/pkg/client/object.go @@ -75,6 +75,8 @@ type GetObjectParams struct { raw bool w io.Writer + + readerHandler ReaderHandler } type ObjectHeaderParams struct { @@ -475,6 +477,78 @@ func (p *GetObjectParams) RawFlag() bool { return false } +// ReaderHandler is a function over io.Reader. +type ReaderHandler func(io.Reader) + +// WithPayloadReaderHandler sets handler of the payload reader. +// +// If provided, payload reader is composed after receiving the header. +// In this case payload writer set via WithPayloadWriter is ignored. +// +// Handler should not be nil. +func (p *GetObjectParams) WithPayloadReaderHandler(f ReaderHandler) *GetObjectParams { + if p != nil { + p.readerHandler = f + } + + return p +} + +// wrapper over the Object Get stream that provides io.Reader. +type objectPayloadReader struct { + stream *rpcapi.GetResponseReader + + resp v2object.GetResponse + + tail []byte +} + +func (x *objectPayloadReader) Read(p []byte) (int, error) { + // read remaining tail + read := copy(p, x.tail) + + x.tail = x.tail[read:] + + if len(p)-read == 0 { + return read, nil + } + + // receive message from server stream + err := x.stream.Read(&x.resp) + if err != nil { + if errors.Is(err, io.EOF) { + return 0, io.EOF + } + + return 0, fmt.Errorf("reading the response failed: %w", err) + } + + // get chunk part message + part := x.resp.GetBody().GetObjectPart() + + chunkPart, ok := part.(*v2object.GetObjectPartChunk) + if !ok { + return 0, errWrongMessageSeq + } + + // verify response structure + if err := signature.VerifyServiceMessage(&x.resp); err != nil { + return 0, fmt.Errorf("response verification failed: %w", err) + } + + // read new chunk + chunk := chunkPart.GetChunk() + + tailOffset := copy(p[read:], chunk) + + read += tailOffset + + // save the tail + x.tail = append(x.tail, chunk[tailOffset:]...) + + return read, nil +} + var errWrongMessageSeq = errors.New("incorrect message sequence") func (c *clientImpl) GetObject(ctx context.Context, p *GetObjectParams, opts ...CallOption) (*object.Object, error) { @@ -527,6 +601,7 @@ func (c *clientImpl) GetObject(ctx context.Context, p *GetObjectParams, opts ... resp = new(v2object.GetResponse) ) +loop: for { // receive message from server stream err := stream.Read(resp) @@ -563,6 +638,14 @@ func (c *clientImpl) GetObject(ctx context.Context, p *GetObjectParams, opts ... hdr := v.GetHeader() obj.SetHeader(hdr) + if p.readerHandler != nil { + p.readerHandler(&objectPayloadReader{ + stream: stream, + }) + + break loop + } + if p.w == nil { payload = make([]byte, 0, hdr.GetPayloadLength()) }