diff --git a/pkg/services/object/put/streamer.go b/pkg/services/object/put/streamer.go index 3ccb80006..c1461a73c 100644 --- a/pkg/services/object/put/streamer.go +++ b/pkg/services/object/put/streamer.go @@ -21,6 +21,8 @@ type Streamer struct { target transformer.ObjectTarget relay func(client.Client) error + + maxPayloadSz uint64 // network config } var errNotInit = errors.New("stream not initialized") @@ -39,6 +41,13 @@ func (p *Streamer) Init(prm *PutInitPrm) error { return nil } +// MaxObjectSize returns maximum payload size for the streaming session. +// +// Must be called after the successful Init. +func (p *Streamer) MaxObjectSize() uint64 { + return p.maxPayloadSz +} + func (p *Streamer) initTarget(prm *PutInitPrm) error { // prevent re-calling if p.target != nil { @@ -50,6 +59,11 @@ func (p *Streamer) initTarget(prm *PutInitPrm) error { return fmt.Errorf("(%T) could not prepare put parameters: %w", p, err) } + p.maxPayloadSz = p.maxSizeSrc.MaxObjectSize() + if p.maxPayloadSz == 0 { + return fmt.Errorf("(%T) could not obtain max object size parameter", p) + } + if prm.hdr.Signature() != nil { p.relay = prm.relay @@ -57,6 +71,8 @@ func (p *Streamer) initTarget(prm *PutInitPrm) error { p.target = &validatingTarget{ nextTarget: p.newCommonTarget(prm), fmt: p.fmtValidator, + + maxPayloadSz: p.maxPayloadSz, } return nil @@ -72,13 +88,8 @@ func (p *Streamer) initTarget(prm *PutInitPrm) error { return fmt.Errorf("(%T) could not receive session key: %w", p, err) } - maxSz := p.maxSizeSrc.MaxObjectSize() - if maxSz == 0 { - return fmt.Errorf("(%T) could not obtain max object size parameter", p) - } - p.target = transformer.NewPayloadSizeLimiter( - maxSz, + p.maxPayloadSz, func() transformer.ObjectTarget { return transformer.NewFormatTarget(&transformer.FormatterParams{ Key: sessionKey, diff --git a/pkg/services/object/put/v2/streamer.go b/pkg/services/object/put/v2/streamer.go index e09122e13..9f07975f9 100644 --- a/pkg/services/object/put/v2/streamer.go +++ b/pkg/services/object/put/v2/streamer.go @@ -1,6 +1,7 @@ package putsvc import ( + "errors" "fmt" "github.com/nspcc-dev/neofs-api-go/pkg/client" @@ -19,8 +20,25 @@ type streamer struct { saveChunks bool init *object.PutRequest chunks []*object.PutRequest + + *sizes // only for relay streams } +type sizes struct { + payloadSz uint64 // value from the header + + writtenPayload uint64 // sum size of already cached chunks +} + +// TODO: errors are copy-pasted from putsvc package +// consider replacing to core library + +// errors related to invalid payload size +var ( + errExceedingMaxSize = errors.New("payload size is greater than the limit") + errWrongPayloadSize = errors.New("wrong payload size") +) + func (s *streamer) Send(req *object.PutRequest) (err error) { switch v := req.GetBody().GetObjectPart().(type) { case *object.PutObjectPartInit: @@ -37,9 +55,29 @@ func (s *streamer) Send(req *object.PutRequest) (err error) { s.saveChunks = v.GetSignature() != nil if s.saveChunks { + maxSz := s.stream.MaxObjectSize() + + s.sizes = &sizes{ + payloadSz: uint64(v.GetHeader().GetPayloadLength()), + } + + // check payload size limit overflow + if s.payloadSz > maxSz { + return errExceedingMaxSize + } + s.init = req } case *object.PutObjectPartChunk: + if s.saveChunks { + s.writtenPayload += uint64(len(v.GetChunk())) + + // check payload size overflow + if s.writtenPayload > s.payloadSz { + return errWrongPayloadSize + } + } + if err = s.stream.SendChunk(toChunkPrm(v)); err != nil { err = fmt.Errorf("(%T) could not send payload chunk: %w", s, err) } @@ -71,6 +109,13 @@ func (s *streamer) Send(req *object.PutRequest) (err error) { } func (s *streamer) CloseAndRecv() (*object.PutResponse, error) { + if s.saveChunks { + // check payload size correctness + if s.writtenPayload != s.payloadSz { + return nil, errWrongPayloadSize + } + } + resp, err := s.stream.Close() if err != nil { return nil, fmt.Errorf("(%T) could not object put stream: %w", s, err) diff --git a/pkg/services/object/put/validation.go b/pkg/services/object/put/validation.go index 6de78f48f..042c6b8ce 100644 --- a/pkg/services/object/put/validation.go +++ b/pkg/services/object/put/validation.go @@ -3,6 +3,7 @@ package putsvc import ( "bytes" "crypto/sha256" + "errors" "fmt" "hash" @@ -20,9 +21,34 @@ type validatingTarget struct { hash hash.Hash checksum []byte + + maxPayloadSz uint64 // network config + + payloadSz uint64 // payload size of the streaming object from header + + writtenPayload uint64 // number of already written payload bytes } +// errors related to invalid payload size +var ( + errExceedingMaxSize = errors.New("payload size is greater than the limit") + errWrongPayloadSize = errors.New("wrong payload size") +) + func (t *validatingTarget) WriteHeader(obj *object.RawObject) error { + t.payloadSz = obj.PayloadSize() + chunkLn := uint64(len(obj.Payload())) + + // check chunk size + if chunkLn > t.payloadSz { + return errWrongPayloadSize + } + + // check payload size limit + if t.payloadSz > t.maxPayloadSz { + return errExceedingMaxSize + } + cs := obj.PayloadChecksum() switch typ := cs.Type(); typ { default: @@ -39,19 +65,47 @@ func (t *validatingTarget) WriteHeader(obj *object.RawObject) error { return fmt.Errorf("(%T) coult not validate object format: %w", t, err) } - return t.nextTarget.WriteHeader(obj) + err := t.nextTarget.WriteHeader(obj) + if err != nil { + return err + } + + // update written bytes + // + // Note: we MUST NOT add obj.PayloadSize() since obj + // can carry only the chunk of the full payload + t.writtenPayload += chunkLn + + return nil } func (t *validatingTarget) Write(p []byte) (n int, err error) { - n, err = t.hash.Write(p) + chunkLn := uint64(len(p)) + + // check if new chunk will overflow payload size + if t.writtenPayload+chunkLn > t.payloadSz { + return 0, errWrongPayloadSize + } + + _, err = t.hash.Write(p) if err != nil { return } - return t.nextTarget.Write(p) + n, err = t.nextTarget.Write(p) + if err == nil { + t.writtenPayload += uint64(n) + } + + return } func (t *validatingTarget) Close() (*transformer.AccessIdentifiers, error) { + // check payload size correctness + if t.payloadSz != t.writtenPayload { + return nil, errWrongPayloadSize + } + if !bytes.Equal(t.hash.Sum(nil), t.checksum) { return nil, fmt.Errorf("(%T) incorrect payload checksum", t) }