From 31b04558828c143af85ee8f4c02e05d02e8a2cad Mon Sep 17 00:00:00 2001 From: Marina Biryukova Date: Fri, 28 Jun 2024 17:05:01 +0300 Subject: [PATCH] Use one get object in PATCH Signed-off-by: Marina Biryukova --- api/layer/layer.go | 183 +++++++++++++++++++++++++--------------- api/layer/layer_test.go | 102 ++++++++++++++++++++++ 2 files changed, 219 insertions(+), 66 deletions(-) create mode 100644 api/layer/layer_test.go diff --git a/api/layer/layer.go b/api/layer/layer.go index f6bb4e5c..90060787 100644 --- a/api/layer/layer.go +++ b/api/layer/layer.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/json" "encoding/xml" + stderrors "errors" "fmt" "io" "net/url" @@ -655,17 +656,17 @@ func (n *layer) CopyObject(ctx context.Context, p *CopyObjectParams) (*data.Exte } func (n *layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.ExtendedObjectInfo, error) { - if p.Range.Start == p.SrcSize { - objPayload, err := n.GetObject(ctx, &GetObjectParams{ - ObjectInfo: p.Object, - Versioned: true, - BucketInfo: p.BktInfo, - Encryption: p.Encryption, - }) - if err != nil { - return nil, fmt.Errorf("get object to patch: %w", err) - } + objPayload, err := n.GetObject(ctx, &GetObjectParams{ + ObjectInfo: p.Object, + Versioned: true, + BucketInfo: p.BktInfo, + Encryption: p.Encryption, + }) + if err != nil { + return nil, fmt.Errorf("get object to patch: %w", err) + } + if p.Range.Start == p.SrcSize { return n.PutObject(ctx, &PutObjectParams{ BktInfo: p.BktInfo, Object: p.Object.Name, @@ -677,6 +678,7 @@ func (n *layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.Ex }) } + var size uint64 if p.Range.Start == 0 { if p.Range.End >= p.SrcSize-1 { return n.PutObject(ctx, &PutObjectParams{ @@ -690,67 +692,18 @@ func (n *layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.Ex }) } - objPayload, err := n.GetObject(ctx, &GetObjectParams{ - ObjectInfo: p.Object, - Range: &RangeParams{Start: p.Range.End + 1, End: p.SrcSize - 1}, - Versioned: true, - BucketInfo: p.BktInfo, - Encryption: p.Encryption, - }) - if err != nil { - return nil, fmt.Errorf("get object range to patch: %w", err) - } - - return n.PutObject(ctx, &PutObjectParams{ - BktInfo: p.BktInfo, - Object: p.Object.Name, - Size: p.SrcSize - 1 - p.Range.End + p.NewBytesSize, - Reader: io.MultiReader(p.NewBytes, objPayload), - Header: p.Header, - Encryption: p.Encryption, - CopiesNumbers: p.CopiesNumbers, - }) - } - - objPayload1, err := n.GetObject(ctx, &GetObjectParams{ - ObjectInfo: p.Object, - Range: &RangeParams{Start: 0, End: p.Range.Start - 1}, - Versioned: true, - BucketInfo: p.BktInfo, - Encryption: p.Encryption, - }) - if err != nil { - return nil, fmt.Errorf("get object range 1 to patch: %w", err) - } - - if p.Range.End >= p.SrcSize-1 { - return n.PutObject(ctx, &PutObjectParams{ - BktInfo: p.BktInfo, - Object: p.Object.Name, - Size: p.Range.Start + p.NewBytesSize, - Reader: io.MultiReader(objPayload1, p.NewBytes), - Header: p.Header, - Encryption: p.Encryption, - CopiesNumbers: p.CopiesNumbers, - }) - } - - objPayload2, err := n.GetObject(ctx, &GetObjectParams{ - ObjectInfo: p.Object, - Range: &RangeParams{Start: p.Range.End + 1, End: p.SrcSize - 1}, - Versioned: true, - BucketInfo: p.BktInfo, - Encryption: p.Encryption, - }) - if err != nil { - return nil, fmt.Errorf("get object range 2 to patch: %w", err) + size = p.SrcSize - 1 - p.Range.End + p.NewBytesSize + } else if p.Range.End >= p.SrcSize-1 { + size = p.Range.Start + p.NewBytesSize + } else { + size = p.SrcSize } return n.PutObject(ctx, &PutObjectParams{ BktInfo: p.BktInfo, Object: p.Object.Name, - Size: p.SrcSize, - Reader: io.MultiReader(objPayload1, p.NewBytes, objPayload2), + Size: size, + Reader: wrapPatchReader(objPayload, p.NewBytes, p.Range, 64*1024), Header: p.Header, Encryption: p.Encryption, CopiesNumbers: p.CopiesNumbers, @@ -993,3 +946,101 @@ func (n *layer) DeleteBucket(ctx context.Context, p *DeleteBucketParams) error { n.cache.DeleteBucket(p.BktInfo) return n.frostFS.DeleteContainer(ctx, p.BktInfo.CID, p.SessionToken) } + +func wrapPatchReader(payload, rngPayload io.Reader, rng *RangeParams, bufSize int) io.Reader { + if payload == nil || rngPayload == nil { + return nil + } + + r, w := io.Pipe() + go func() { + var buf = make([]byte, bufSize) + + if rng.Start == 0 { + err := readRange(rngPayload, w, buf) + if err != nil { + _ = w.CloseWithError(err) + return + } + + var readSize uint64 + for { + n, err := payload.Read(buf) + if err != nil && !stderrors.Is(err, io.EOF) { + _ = w.CloseWithError(err) + break + } + readSize += uint64(n) + if readSize > rng.End+1 { + var start uint64 + if readSize-rng.End-1 < uint64(n) { + start = uint64(n) - (readSize - rng.End - 1) + } + _, _ = w.Write(buf[start:n]) + } + if stderrors.Is(err, io.EOF) { + _ = w.CloseWithError(err) + break + } + } + } else { + var ( + readSize uint64 + readRng bool + ) + + for { + n, err := payload.Read(buf) + if err != nil && !stderrors.Is(err, io.EOF) { + _ = w.CloseWithError(err) + break + } + readSize += uint64(n) + if readSize <= rng.Start { + _, _ = w.Write(buf[:n]) + continue + } + if readSize-rng.Start < uint64(n) { + _, _ = w.Write(buf[:n-int(readSize-rng.Start)]) + } + if !readRng { + err = readRange(rngPayload, w, buf) + if err != nil { + _ = w.CloseWithError(err) + break + } + readRng = true + } + if readSize > rng.End+1 { + var start uint64 + if readSize-rng.End-1 < uint64(n) { + start = uint64(n) - (readSize - rng.End - 1) + } + _, _ = w.Write(buf[start:n]) + } + if stderrors.Is(err, io.EOF) { + _ = w.CloseWithError(err) + break + } + } + } + }() + return r +} + +func readRange(r io.Reader, w *io.PipeWriter, buf []byte) error { + for { + n, err := r.Read(buf) + if n > 0 { + _, _ = w.Write(buf[:n]) + } + if err != nil { + if !stderrors.Is(err, io.EOF) { + return err + } + break + } + } + + return nil +} diff --git a/api/layer/layer_test.go b/api/layer/layer_test.go new file mode 100644 index 00000000..aa55733c --- /dev/null +++ b/api/layer/layer_test.go @@ -0,0 +1,102 @@ +package layer + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWrapPatchReader(t *testing.T) { + payload := "abcdefghijklmn" + rngPayload := "123" + + for _, tc := range []struct { + name string + rng *RangeParams + bufSize int + expected string + }{ + { + name: "patch object start, buffer is less than range size", + rng: &RangeParams{ + Start: 0, + End: 2, + }, + bufSize: 2, + expected: "123defghijklmn", + }, + { + name: "patch object start, buffer is equal to range size", + rng: &RangeParams{ + Start: 0, + End: 2, + }, + bufSize: 3, + expected: "123defghijklmn", + }, + { + name: "patch object start, buffer is greater than range size", + rng: &RangeParams{ + Start: 0, + End: 2, + }, + bufSize: 4, + expected: "123defghijklmn", + }, + { + name: "patch object middle, range at the beginning of buffer", + rng: &RangeParams{ + Start: 5, + End: 7, + }, + bufSize: 5, + expected: "abcde123ijklmn", + }, + { + name: "patch object middle, range in the middle of buffer", + rng: &RangeParams{ + Start: 6, + End: 8, + }, + bufSize: 5, + expected: "abcdef123jklmn", + }, + { + name: "patch object middle, range in the end of buffer", + rng: &RangeParams{ + Start: 7, + End: 9, + }, + bufSize: 5, + expected: "abcdefg123klmn", + }, + { + name: "patch object end, increase size", + rng: &RangeParams{ + Start: 12, + End: 14, + }, + bufSize: 4, + expected: "abcdefghijkl123", + }, + { + name: "patch object end", + rng: &RangeParams{ + Start: 11, + End: 13, + }, + bufSize: 4, + expected: "abcdefghijk123", + }, + } { + t.Run(tc.name, func(t *testing.T) { + wrappedReader := wrapPatchReader(bytes.NewBufferString(payload), bytes.NewBufferString(rngPayload), tc.rng, tc.bufSize) + + res, err := io.ReadAll(wrappedReader) + require.NoError(t, err) + require.Equal(t, tc.expected, string(res)) + }) + } +}