diff --git a/client/object_patch.go b/client/object_patch.go index b30bebe..ca7e8ed 100644 --- a/client/object_patch.go +++ b/client/object_patch.go @@ -106,7 +106,6 @@ func (c *Client) ObjectPatchInit(ctx context.Context, prm PrmObjectPatch) (Objec } objectPatcher.client = c objectPatcher.stream = stream - objectPatcher.firstPatchPayload = true if prm.MaxChunkLength > 0 { objectPatcher.maxChunkLen = prm.MaxChunkLength @@ -154,8 +153,6 @@ type objectPatcher struct { respV2 v2object.PatchResponse maxChunkLen int - - firstPatchPayload bool } func (x *objectPatcher) PatchAttributes(_ context.Context, newAttrs []object.Attribute, replace bool) bool { @@ -171,19 +168,33 @@ func (x *objectPatcher) PatchPayload(_ context.Context, rng *object.Range, paylo buf := make([]byte, x.maxChunkLen) - for { + for patchIter := 0; ; patchIter++ { n, err := payloadReader.Read(buf) if err != nil && err != io.EOF { x.err = fmt.Errorf("read payload: %w", err) return false } if n == 0 { + if patchIter == 0 { + if rng.GetLength() == 0 { + x.err = errors.New("zero-length empty payload patch can't be applied") + return false + } + if !x.patch(&object.Patch{ + Address: x.addr, + PayloadPatch: &object.PayloadPatch{ + Range: rng, + Chunk: []byte{}, + }, + }) { + return false + } + } break } rngPart := object.NewRange() - if x.firstPatchPayload { - x.firstPatchPayload = false + if patchIter == 0 { rngPart.SetOffset(offset) rngPart.SetLength(rng.GetLength()) } else { diff --git a/client/object_patch_test.go b/client/object_patch_test.go index 9c87820..839c453 100644 --- a/client/object_patch_test.go +++ b/client/object_patch_test.go @@ -170,12 +170,11 @@ func TestObjectPatcher(t *testing.T) { pk, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) patcher := objectPatcher{ - client: &Client{}, - stream: m, - addr: oidtest.Address(), - key: pk, - maxChunkLen: test.maxChunkLen, - firstPatchPayload: true, + client: &Client{}, + stream: m, + addr: oidtest.Address(), + key: pk, + maxChunkLen: test.maxChunkLen, } success := patcher.PatchAttributes(context.Background(), nil, false) @@ -194,6 +193,93 @@ func TestObjectPatcher(t *testing.T) { } } +func TestRepeatPayloadPatch(t *testing.T) { + t.Run("no payload patch partioning", func(t *testing.T) { + m := &mockPatchStream{} + + pk, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + const maxChunkLen = 20 + + patcher := objectPatcher{ + client: &Client{}, + stream: m, + addr: oidtest.Address(), + key: pk, + maxChunkLen: maxChunkLen, + } + + for _, pp := range []struct { + patchPayload string + rng *object.Range + }{ + { + patchPayload: "xxxxxxxxxx", + rng: newRange(1, 6), + }, + { + patchPayload: "yyyyyyyyyy", + rng: newRange(5, 9), + }, + { + patchPayload: "zzzzzzzzzz", + rng: newRange(10, 0), + }, + } { + success := patcher.PatchPayload(context.Background(), pp.rng, bytes.NewReader([]byte(pp.patchPayload))) + require.True(t, success) + } + + requireRangeChunk(t, m.streamedPayloadPatches[0], 1, 6, "xxxxxxxxxx") + requireRangeChunk(t, m.streamedPayloadPatches[1], 5, 9, "yyyyyyyyyy") + requireRangeChunk(t, m.streamedPayloadPatches[2], 10, 0, "zzzzzzzzzz") + }) + + t.Run("payload patch partioning", func(t *testing.T) { + m := &mockPatchStream{} + + pk, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + const maxChunkLen = 5 + + patcher := objectPatcher{ + client: &Client{}, + stream: m, + addr: oidtest.Address(), + key: pk, + maxChunkLen: maxChunkLen, + } + + for _, pp := range []struct { + patchPayload string + rng *object.Range + }{ + { + patchPayload: "xxxxxxxxxx", + rng: newRange(1, 6), + }, + { + patchPayload: "yyyyyyyyyy", + rng: newRange(5, 9), + }, + { + patchPayload: "zzzzzzzzzz", + rng: newRange(10, 0), + }, + } { + success := patcher.PatchPayload(context.Background(), pp.rng, bytes.NewReader([]byte(pp.patchPayload))) + require.True(t, success) + } + + requireRangeChunk(t, m.streamedPayloadPatches[0], 1, 6, "xxxxx") + requireRangeChunk(t, m.streamedPayloadPatches[1], 7, 0, "xxxxx") + requireRangeChunk(t, m.streamedPayloadPatches[2], 5, 9, "yyyyy") + requireRangeChunk(t, m.streamedPayloadPatches[3], 14, 0, "yyyyy") + requireRangeChunk(t, m.streamedPayloadPatches[4], 10, 0, "zzzzz") + requireRangeChunk(t, m.streamedPayloadPatches[5], 10, 0, "zzzzz") + }) +} + func requireRangeChunk(t *testing.T, pp *object.PayloadPatch, offset, length int, chunk string) { require.NotNil(t, pp) require.Equal(t, uint64(offset), pp.Range.GetOffset())