diff --git a/api/handler/patch_test.go b/api/handler/patch_test.go index 65b12f2..92970b0 100644 --- a/api/handler/patch_test.go +++ b/api/handler/patch_test.go @@ -3,6 +3,7 @@ package handler import ( "bytes" "crypto/md5" + "crypto/rand" "crypto/sha256" "encoding/hex" "encoding/xml" @@ -107,6 +108,115 @@ func TestPatch(t *testing.T) { } } +func TestPatchMultipartObject(t *testing.T) { + tc := prepareHandlerContextWithMinCache(t) + + bktName, objName, partSize := "bucket-for-multipart-patch", "object-for-multipart-patch", 5*1024*1024 + createTestBucket(tc, bktName) + + // patch beginning of the first part + multipartInfo := createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 := uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, data2 := uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, data3 := uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchSize := partSize / 2 + patchBody := make([]byte, patchSize) + _, err := rand.Read(patchBody) + require.NoError(t, err) + + patchObject(t, tc, bktName, objName, "bytes 0-"+strconv.Itoa(patchSize-1)+"/*", patchBody, nil) + data, header := getObject(tc, bktName, objName) + contentLen, err := strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{patchBody, data1[patchSize:], data2, data3}, []byte("")), data) + require.Equal(t, partSize*3, contentLen) + + // patch middle of the first part + multipartInfo = createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, data2 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, data3 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchObject(t, tc, bktName, objName, "bytes "+strconv.Itoa(partSize/4)+"-"+strconv.Itoa(partSize*3/4-1)+"/*", patchBody, nil) + data, header = getObject(tc, bktName, objName) + contentLen, err = strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{data1[:partSize/4], patchBody, data1[partSize*3/4:], data2, data3}, []byte("")), data) + require.Equal(t, partSize*3, contentLen) + + // patch few parts + multipartInfo = createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, data3 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchSize = partSize * 2 + patchBody = make([]byte, patchSize) + _, err = rand.Read(patchBody) + require.NoError(t, err) + + patchObject(t, tc, bktName, objName, "bytes "+strconv.Itoa(partSize/2-1)+"-"+strconv.Itoa(partSize/2+patchSize-2)+"/*", patchBody, nil) + data, header = getObject(tc, bktName, objName) + contentLen, err = strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{data1[:partSize/2-1], patchBody, data3[partSize/2-1:]}, []byte("")), data) + require.Equal(t, partSize*3, contentLen) + + // patch last part + multipartInfo = createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, data2 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchBody = make([]byte, partSize) + _, err = rand.Read(patchBody) + require.NoError(t, err) + + patchObject(t, tc, bktName, objName, "bytes "+strconv.Itoa(partSize*2)+"-"+strconv.Itoa(partSize*3-1)+"/*", patchBody, nil) + data, header = getObject(tc, bktName, objName) + contentLen, err = strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{data1, data2, patchBody}, []byte("")), data) + require.Equal(t, partSize*3, contentLen) + + // patch last part and append bytes + multipartInfo = createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, data2 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, data3 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchObject(t, tc, bktName, objName, "bytes "+strconv.Itoa(partSize*2+3)+"-"+strconv.Itoa(partSize*3+2)+"/*", patchBody, nil) + data, header = getObject(tc, bktName, objName) + contentLen, err = strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{data1, data2, data3[:3], patchBody}, []byte("")), data) + require.Equal(t, partSize*3+3, contentLen) + + // append bytes + multipartInfo = createMultipartUpload(tc, bktName, objName, map[string]string{}) + etag1, data1 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, data2 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, data3 = uploadPart(tc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(tc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + + patchBody = make([]byte, partSize) + _, err = rand.Read(patchBody) + require.NoError(t, err) + + patchObject(t, tc, bktName, objName, "bytes "+strconv.Itoa(partSize*3)+"-"+strconv.Itoa(partSize*4-1)+"/*", patchBody, nil) + data, header = getObject(tc, bktName, objName) + contentLen, err = strconv.Atoi(header.Get(api.ContentLength)) + require.NoError(t, err) + require.Equal(t, bytes.Join([][]byte{data1, data2, data3, patchBody}, []byte("")), data) + require.Equal(t, partSize*4, contentLen) +} + func TestPatchWithVersion(t *testing.T) { hc := prepareHandlerContextWithMinCache(t) bktName, objName := "bucket", "obj" diff --git a/api/layer/frostfs.go b/api/layer/frostfs.go index 4d59768..e23a2b9 100644 --- a/api/layer/frostfs.go +++ b/api/layer/frostfs.go @@ -209,10 +209,13 @@ type PrmObjectPatch struct { Payload io.Reader // Object range to patch. - Range *RangeParams + Range RangeParams // Size of original object payload. ObjectSize uint64 + + // New object attributes. + Attributes []object.Attribute } var ( diff --git a/api/layer/frostfs_mock.go b/api/layer/frostfs_mock.go index 91e3661..fba950f 100644 --- a/api/layer/frostfs_mock.go +++ b/api/layer/frostfs_mock.go @@ -435,6 +435,7 @@ func (t *TestFrostFS) PatchObject(ctx context.Context, prm PrmObjectPatch) (oid. newID := oidtest.ID() newObj.SetID(newID) + newObj.SetAttributes(mergeAttributes(obj.Attributes(), prm.Attributes)...) t.objects[newAddress(prm.Container, newID).EncodeToString()] = &newObj @@ -470,3 +471,20 @@ func isMatched(attributes []object.Attribute, filter object.SearchFilter) bool { } return false } + +func mergeAttributes(oldAttributes []object.Attribute, newAttributes []object.Attribute) []object.Attribute { + for i := range newAttributes { + var found bool + for j := range oldAttributes { + if oldAttributes[j].Key() == newAttributes[i].Key() { + oldAttributes[j].SetValue(newAttributes[i].Value()) + found = true + break + } + } + if !found { + oldAttributes = append(oldAttributes, newAttributes[i]) + } + } + return oldAttributes +} diff --git a/api/layer/layer.go b/api/layer/layer.go index 3f01afe..3305476 100644 --- a/api/layer/layer.go +++ b/api/layer/layer.go @@ -1,6 +1,7 @@ package layer import ( + "bytes" "context" "crypto/ecdsa" "crypto/rand" @@ -27,6 +28,7 @@ import ( "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client" cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" @@ -547,15 +549,14 @@ func (n *Layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.Ex } if p.Object.Headers[MultipartObjectSize] != "" { - // TODO: support multipart object patch - return nil, fmt.Errorf("patch multipart object") + return n.patchMultipartObject(ctx, p) } prmPatch := PrmObjectPatch{ Container: p.BktInfo.CID, Object: p.Object.ID, Payload: p.NewBytes, - Range: p.Range, + Range: *p.Range, ObjectSize: p.Object.Size, } n.prepareAuthParameters(ctx, &prmPatch.PrmAuth, p.BktInfo.Owner) @@ -601,6 +602,158 @@ func (n *Layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.Ex }, nil } +func (n *Layer) patchMultipartObject(ctx context.Context, p *PatchObjectParams) (*data.ExtendedObjectInfo, error) { + combinedObj, err := n.objectGet(ctx, p.BktInfo, p.Object.ID) + if err != nil { + return nil, fmt.Errorf("get combined object '%s': %w", p.Object.ID.EncodeToString(), err) + } + + var parts []*data.PartInfo + if err = json.NewDecoder(combinedObj.Payload).Decode(&parts); err != nil { + return nil, fmt.Errorf("unmarshal combined object parts: %w", err) + } + + prmPatch := PrmObjectPatch{ + Container: p.BktInfo.CID, + } + n.prepareAuthParameters(ctx, &prmPatch.PrmAuth, p.BktInfo.Owner) + + prmHead := PrmObjectHead{ + PrmAuth: prmPatch.PrmAuth, + Container: p.BktInfo.CID, + } + + allPatchLen := p.Range.End - p.Range.Start + 1 + patchReader := newPartialReader(p.NewBytes, 64*1024) + var multipartObjectSize, patchedLen uint64 + for i, part := range parts { + if patchedLen == allPatchLen { + multipartObjectSize += part.Size + continue + } + + if p.Range.Start > part.Size || (p.Range.Start == part.Size && i != len(parts)-1) { + multipartObjectSize += part.Size + p.Range.Start -= part.Size + p.Range.End -= part.Size + continue + } + + prmPatch.Object = part.OID + prmPatch.ObjectSize = part.Size + prmPatch.Range = *p.Range + + var patchLen uint64 + if i != len(parts)-1 { + if prmPatch.Range.End >= part.Size-1 { + prmPatch.Range.End = part.Size - 1 + } + patchLen = prmPatch.Range.End - prmPatch.Range.Start + 1 + } else { + patchLen = allPatchLen - patchedLen + } + prmPatch.Payload = patchReader.partReader(patchLen) + patchedLen += patchLen + + objID, err := n.frostFS.PatchObject(ctx, prmPatch) + if err != nil { + return nil, fmt.Errorf("patch object '%s': %w", part.OID.EncodeToString(), err) + } + + prmHead.Object = objID + obj, err := n.frostFS.HeadObject(ctx, prmHead) + if err != nil { + return nil, fmt.Errorf("head object '%s': %w", objID.EncodeToString(), err) + } + + payloadChecksum, _ := obj.PayloadChecksum() + hashSum := hex.EncodeToString(payloadChecksum.Value()) + + parts[i].OID = objID + parts[i].Size = obj.PayloadSize() + parts[i].MD5 = "" + parts[i].ETag = hashSum + + multipartObjectSize += obj.PayloadSize() + + if p.Range.Start > 0 { + p.Range.Start = 0 + } + p.Range.End -= part.Size + } + + newParts, err := json.Marshal(parts) + if err != nil { + return nil, fmt.Errorf("marshal parts for combined object: %w", err) + } + + prmPatch.Object = p.Object.ID + prmPatch.ObjectSize = p.Object.Size + prmPatch.Range.Start = 0 + prmPatch.Range.End = p.Object.Size - 1 + prmPatch.Payload = bytes.NewReader(newParts) + prmPatch.Attributes = make([]object.Attribute, 0, 2) + + var a object.Attribute + a.SetKey(MultipartObjectSize) + a.SetValue(strconv.FormatUint(multipartObjectSize, 10)) + prmPatch.Attributes = append(prmPatch.Attributes, a) + + var headerParts strings.Builder + for i, part := range parts { + headerPart := part.ToHeaderString() + if i != len(parts)-1 { + headerPart += "," + } + headerParts.WriteString(headerPart) + } + + a.SetKey(UploadCompletedParts) + a.SetValue(headerParts.String()) + prmPatch.Attributes = append(prmPatch.Attributes, a) + + objID, err := n.frostFS.PatchObject(ctx, prmPatch) + if err != nil { + return nil, fmt.Errorf("patch completed object: %w", err) + } + + prmHead.Object = objID + obj, err := n.frostFS.HeadObject(ctx, prmHead) + if err != nil { + return nil, fmt.Errorf("head completed object: %w", err) + } + + payloadChecksum, _ := obj.PayloadChecksum() + hashSum := hex.EncodeToString(payloadChecksum.Value()) + newVersion := &data.NodeVersion{ + BaseNodeVersion: data.BaseNodeVersion{ + OID: objID, + ETag: hashSum, + FilePath: p.Object.Name, + Size: obj.PayloadSize(), + Created: &p.Object.Created, + Owner: &n.gateOwner, + // TODO: Add creation epoch + }, + IsUnversioned: !p.VersioningEnabled, + IsCombined: p.Object.Headers[MultipartObjectSize] != "", + } + + if newVersion.ID, err = n.treeService.AddVersion(ctx, p.BktInfo, newVersion); err != nil { + return nil, fmt.Errorf("couldn't add new verion to tree service: %w", err) + } + + p.Object.ID = objID + p.Object.Size = obj.PayloadSize() + p.Object.MD5Sum = "" + p.Object.HashSum = hashSum + + return &data.ExtendedObjectInfo{ + ObjectInfo: p.Object, + NodeVersion: newVersion, + }, nil +} + func getRandomOID() (oid.ID, error) { b := [32]byte{} if _, err := rand.Read(b[:]); err != nil { diff --git a/api/layer/partial_reader.go b/api/layer/partial_reader.go new file mode 100644 index 0000000..fd6c1e6 --- /dev/null +++ b/api/layer/partial_reader.go @@ -0,0 +1,67 @@ +package layer + +import "io" + +type partialReader struct { + r io.Reader + buf []byte + bufRemains int + bufRemainsStart int +} + +func newPartialReader(r io.Reader, bufSize int) *partialReader { + return &partialReader{ + r: r, + buf: make([]byte, bufSize), + } +} + +func (p *partialReader) partReader(length uint64) io.Reader { + r, w := io.Pipe() + + go func() { + if p.bufRemains > 0 { + if length <= uint64(p.bufRemains) { + _, _ = w.Write(p.buf[p.bufRemainsStart : p.bufRemainsStart+int(length)]) + p.bufRemains -= int(length) + p.bufRemainsStart += int(length) + if p.bufRemains == 0 { + p.bufRemainsStart = 0 + } + _ = w.CloseWithError(io.EOF) + return + } + + _, _ = w.Write(p.buf[p.bufRemainsStart : p.bufRemainsStart+p.bufRemains]) + length -= uint64(p.bufRemains) + p.bufRemains = 0 + p.bufRemainsStart = 0 + } + + for { + n, err := p.r.Read(p.buf) + if n > 0 { + if length <= uint64(n) { + _, _ = w.Write(p.buf[:length]) + p.bufRemains = n - int(length) + p.bufRemainsStart = int(length) + if p.bufRemains == 0 { + p.bufRemainsStart = 0 + } + _ = w.CloseWithError(io.EOF) + break + } + + _, _ = w.Write(p.buf[:n]) + length -= uint64(n) + } + + if err != nil { + _ = w.CloseWithError(err) + break + } + } + }() + + return r +} diff --git a/api/layer/partial_reader_test.go b/api/layer/partial_reader_test.go new file mode 100644 index 0000000..a6fd196 --- /dev/null +++ b/api/layer/partial_reader_test.go @@ -0,0 +1,101 @@ +package layer + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPartialReader(t *testing.T) { + for _, tc := range []struct { + name string + reader io.Reader + bufSize int + lengths []uint64 + expectedBytes [][]byte + }{ + { + name: "buf size is equal to length", + reader: strings.NewReader("abcdef"), + bufSize: 3, + lengths: []uint64{3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def")}, + }, + { + name: "buf size is equal to length, read more than write", + reader: strings.NewReader("abcdef"), + bufSize: 3, + lengths: []uint64{3, 3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def"), []byte("")}, + }, + { + name: "buf size is greater than length", + reader: strings.NewReader("abcdefg"), + bufSize: 4, + lengths: []uint64{3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def")}, + }, + { + name: "buf size is greater than length, read more than write", + reader: strings.NewReader("abcdefg"), + bufSize: 4, + lengths: []uint64{3, 3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def"), []byte("g")}, + }, + { + name: "buf size is less than length", + reader: strings.NewReader("abcdefg"), + bufSize: 2, + lengths: []uint64{3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def")}, + }, + { + name: "buf size is less than length, read more than write", + reader: strings.NewReader("abcdefg"), + bufSize: 2, + lengths: []uint64{3, 3, 3}, + expectedBytes: [][]byte{[]byte("abc"), []byte("def"), []byte("g")}, + }, + { + name: "length is divided by buf size", + reader: strings.NewReader("abcdefghi"), + bufSize: 2, + lengths: []uint64{4, 4}, + expectedBytes: [][]byte{[]byte("abcd"), []byte("efgh")}, + }, + { + name: "length is divided by buf size, read more than write", + reader: strings.NewReader("abcdefghij"), + bufSize: 2, + lengths: []uint64{4, 4, 4}, + expectedBytes: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ij")}, + }, + { + name: "buf size is divided by length", + reader: strings.NewReader("abcdefg"), + bufSize: 4, + lengths: []uint64{2, 2, 2}, + expectedBytes: [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")}, + }, + { + name: "buf size is divided by length, read more than write", + reader: strings.NewReader("abcdefg"), + bufSize: 4, + lengths: []uint64{2, 2, 2, 2}, + expectedBytes: [][]byte{[]byte("ab"), []byte("cd"), []byte("ef"), []byte("g")}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + r := newPartialReader(tc.reader, tc.bufSize) + + for i, length := range tc.lengths { + partR := r.partReader(length) + res, err := io.ReadAll(partR) + require.NoError(t, err) + require.Equal(t, tc.expectedBytes[i], res) + } + }) + } +} diff --git a/internal/frostfs/frostfs.go b/internal/frostfs/frostfs.go index 54519b5..99656e0 100644 --- a/internal/frostfs/frostfs.go +++ b/internal/frostfs/frostfs.go @@ -406,6 +406,7 @@ func (x *FrostFS) PatchObject(ctx context.Context, prm layer.PrmObjectPatch) (oi prmPatch.SetRange(&rng) prmPatch.SetPayloadReader(prm.Payload) + prmPatch.SetNewAttributes(prm.Attributes) if prm.BearerToken != nil { prmPatch.UseBearer(*prm.BearerToken)