diff --git a/pkg/services/object/put/streamer.go b/pkg/services/object/put/streamer.go index ea885366b..f1ecd4df1 100644 --- a/pkg/services/object/put/streamer.go +++ b/pkg/services/object/put/streamer.go @@ -77,7 +77,7 @@ func (p *Streamer) initUntrustedTarget(prm *PutInitPrm) error { p.relay = prm.relay // prepare untrusted-Put object target - p.target = &validatingTarget{ + p.target = &validatingPreparedTarget{ nextTarget: p.newCommonTarget(prm), fmt: p.fmtValidator, @@ -125,8 +125,7 @@ func (p *Streamer) initTrustedTarget(prm *PutInitPrm) error { p.sessionKey = sessionKey p.target = &validatingTarget{ - fmt: p.fmtValidator, - unpreparedObject: true, + fmt: p.fmtValidator, nextTarget: transformer.NewPayloadSizeLimiter( p.maxPayloadSz, containerSDK.IsHomomorphicHashingDisabled(prm.cnr), diff --git a/pkg/services/object/put/validation.go b/pkg/services/object/put/validation.go index 8c40d0677..a4790071a 100644 --- a/pkg/services/object/put/validation.go +++ b/pkg/services/object/put/validation.go @@ -15,13 +15,18 @@ import ( "git.frostfs.info/TrueCloudLab/tzhash/tz" ) -// validatingTarget validates object format and content. +// validatingTarget validates unprepared object format and content (streaming PUT case). type validatingTarget struct { nextTarget transformer.ObjectTarget fmt *object.FormatValidator +} - unpreparedObject bool +// validatingPreparedTarget validates prepared object format and content. +type validatingPreparedTarget struct { + nextTarget transformer.ObjectTarget + + fmt *object.FormatValidator hash hash.Hash @@ -42,38 +47,52 @@ var ( ) func (t *validatingTarget) WriteHeader(ctx context.Context, obj *objectSDK.Object) error { + if err := t.fmt.Validate(ctx, obj, true); err != nil { + return fmt.Errorf("(%T) coult not validate object format: %w", t, err) + } + + return t.nextTarget.WriteHeader(ctx, obj) +} + +func (t *validatingTarget) Write(ctx context.Context, p []byte) (n int, err error) { + return t.nextTarget.Write(ctx, p) +} + +func (t *validatingTarget) Close(ctx context.Context) (*transformer.AccessIdentifiers, error) { + return t.nextTarget.Close(ctx) +} + +func (t *validatingPreparedTarget) WriteHeader(ctx context.Context, obj *objectSDK.Object) error { t.payloadSz = obj.PayloadSize() chunkLn := uint64(len(obj.Payload())) - if !t.unpreparedObject { - // check chunk size - if chunkLn > t.payloadSz { - return ErrWrongPayloadSize - } - - // check payload size limit - if t.payloadSz > t.maxPayloadSz { - return ErrExceedingMaxSize - } - - cs, csSet := obj.PayloadChecksum() - if !csSet { - return errors.New("missing payload checksum") - } - - switch typ := cs.Type(); typ { - default: - return fmt.Errorf("(%T) unsupported payload checksum type %v", t, typ) - case checksum.SHA256: - t.hash = sha256.New() - case checksum.TZ: - t.hash = tz.New() - } - - t.checksum = cs.Value() + // check chunk size + if chunkLn > t.payloadSz { + return ErrWrongPayloadSize } - if err := t.fmt.Validate(ctx, obj, t.unpreparedObject); err != nil { + // check payload size limit + if t.payloadSz > t.maxPayloadSz { + return ErrExceedingMaxSize + } + + cs, csSet := obj.PayloadChecksum() + if !csSet { + return errors.New("missing payload checksum") + } + + switch typ := cs.Type(); typ { + default: + return fmt.Errorf("(%T) unsupported payload checksum type %v", t, typ) + case checksum.SHA256: + t.hash = sha256.New() + case checksum.TZ: + t.hash = tz.New() + } + + t.checksum = cs.Value() + + if err := t.fmt.Validate(ctx, obj, false); err != nil { return fmt.Errorf("(%T) coult not validate object format: %w", t, err) } @@ -82,30 +101,26 @@ func (t *validatingTarget) WriteHeader(ctx context.Context, obj *objectSDK.Objec return err } - if !t.unpreparedObject { - // update written bytes - // - // Note: we MUST NOT add obj.PayloadSize() since obj - // can carry only the chunk of the full payload - t.writtenPayload += chunkLn - } + // update written bytes + // + // Note: we MUST NOT add obj.PayloadSize() since obj + // can carry only the chunk of the full payload + t.writtenPayload += chunkLn return nil } -func (t *validatingTarget) Write(ctx context.Context, p []byte) (n int, err error) { +func (t *validatingPreparedTarget) Write(ctx context.Context, p []byte) (n int, err error) { chunkLn := uint64(len(p)) - if !t.unpreparedObject { - // check if new chunk will overflow payload size - if t.writtenPayload+chunkLn > t.payloadSz { - return 0, ErrWrongPayloadSize - } + // check if new chunk will overflow payload size + if t.writtenPayload+chunkLn > t.payloadSz { + return 0, ErrWrongPayloadSize + } - _, err = t.hash.Write(p) - if err != nil { - return - } + _, err = t.hash.Write(p) + if err != nil { + return } n, err = t.nextTarget.Write(ctx, p) @@ -116,16 +131,14 @@ func (t *validatingTarget) Write(ctx context.Context, p []byte) (n int, err erro return } -func (t *validatingTarget) Close(ctx context.Context) (*transformer.AccessIdentifiers, error) { - if !t.unpreparedObject { - // check payload size correctness - if t.payloadSz != t.writtenPayload { - return nil, ErrWrongPayloadSize - } +func (t *validatingPreparedTarget) Close(ctx context.Context) (*transformer.AccessIdentifiers, error) { + // check payload size correctness + if t.payloadSz != t.writtenPayload { + return nil, ErrWrongPayloadSize + } - if !bytes.Equal(t.hash.Sum(nil), t.checksum) { - return nil, fmt.Errorf("(%T) incorrect payload checksum", t) - } + if !bytes.Equal(t.hash.Sum(nil), t.checksum) { + return nil, fmt.Errorf("(%T) incorrect payload checksum", t) } return t.nextTarget.Close(ctx)