diff --git a/api/handler/acl_test.go b/api/handler/acl_test.go index fb1c46b..6f0ae90 100644 --- a/api/handler/acl_test.go +++ b/api/handler/acl_test.go @@ -7,6 +7,7 @@ import ( "encoding/xml" "net/http" "net/http/httptest" + "strconv" "testing" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" @@ -28,7 +29,7 @@ func TestPutObjectACLErrorAPE(t *testing.T) { info := createBucket(hc, bktName) - putObjectWithHeadersAssertS3Error(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPublic}, s3errors.ErrAccessControlListNotSupported) + putObjectWithHeadersAssertS3Error(hc, bktName, objName, "", map[string]string{api.AmzACL: basicACLPublic}, s3errors.ErrAccessControlListNotSupported) putObjectWithHeaders(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPrivate}) // only `private` canned acl is allowed, that is actually ignored putObjectWithHeaders(hc, bktName, objName, nil) @@ -396,8 +397,14 @@ func putObjectWithHeaders(hc *handlerContext, bktName, objName string, headers m return w.Header() } -func putObjectWithHeadersAssertS3Error(hc *handlerContext, bktName, objName string, headers map[string]string, code s3errors.ErrorCode) { - w := putObjectWithHeadersBase(hc, bktName, objName, headers, nil, nil) +func putObjectContentWithHeaders(hc *handlerContext, bktName, objName, content string, headers map[string]string) http.Header { + w := putObjectWithHeadersBase(hc, bktName, objName, headers, nil, []byte(content)) + assertStatus(hc.t, w, http.StatusOK) + return w.Header() +} + +func putObjectWithHeadersAssertS3Error(hc *handlerContext, bktName, objName, content string, headers map[string]string, code s3errors.ErrorCode) { + w := putObjectWithHeadersBase(hc, bktName, objName, headers, nil, []byte(content)) assertS3Error(hc.t, w, s3errors.GetAPIError(code)) } @@ -408,6 +415,7 @@ func putObjectWithHeadersBase(hc *handlerContext, bktName, objName string, heade for k, v := range headers { r.Header.Set(k, v) } + r.Header.Set(api.ContentLength, strconv.Itoa(len(data))) ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}) r = r.WithContext(ctx) diff --git a/api/handler/get.go b/api/handler/get.go index acdc375..302c64d 100644 --- a/api/handler/get.go +++ b/api/handler/get.go @@ -25,7 +25,7 @@ type conditionalArgs struct { func fetchRangeHeader(headers http.Header, fullSize uint64) (*layer.RangeParams, error) { const prefix = "bytes=" - rangeHeader := headers.Get("Range") + rangeHeader := headers.Get(api.Range) if len(rangeHeader) == 0 { return nil, nil } diff --git a/api/handler/put.go b/api/handler/put.go index 8a7a10f..e878e13 100644 --- a/api/handler/put.go +++ b/api/handler/put.go @@ -190,6 +190,11 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) { reqInfo = middleware.GetReqInfo(ctx) ) + if rangeStr := r.Header.Get(api.Range); rangeStr != "" { + h.putObjectWithRange(w, r) + return + } + bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName) if err != nil { h.logAndSendError(w, "could not get bucket objInfo", reqInfo, err) @@ -310,6 +315,171 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) { } } +func (h *handler) putObjectWithRange(w http.ResponseWriter, r *http.Request) { + var ( + err error + ctx = r.Context() + reqInfo = middleware.GetReqInfo(ctx) + rangeStr = r.Header.Get(api.Range) + ) + + bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName) + if err != nil { + h.logAndSendError(w, "could not get bucket objInfo", reqInfo, err) + return + } + + settings, err := h.obj.GetBucketSettings(ctx, bktInfo) + if err != nil { + h.logAndSendError(w, "could not get bucket settings", reqInfo, err) + return + } + + body, err := h.getBodyReader(r) + if err != nil { + h.logAndSendError(w, "failed to get body reader", reqInfo, err) + return + } + + srcObjPrm := &layer.HeadObjectParams{ + Object: reqInfo.ObjectName, + BktInfo: bktInfo, + VersionID: reqInfo.URL.Query().Get(api.QueryVersionID), + } + + extendedSrcObjInfo, err := h.obj.GetExtendedObjectInfo(ctx, srcObjPrm) + if err != nil { + h.logAndSendError(w, "could not find object", reqInfo, err) + return + } + + srcSize, err := layer.GetObjectSize(extendedSrcObjInfo.ObjectInfo) + if err != nil { + h.logAndSendError(w, "failed to get source object size", reqInfo, err) + return + } + + var contentLen uint64 + if r.ContentLength > 0 { + contentLen = uint64(r.ContentLength) + } + + byteRange, overwrite, err := parsePutRange(rangeStr, srcSize, contentLen) + if err != nil { + h.logAndSendError(w, "could not parse byte range", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err)) + return + } + + if maxPatchSize < byteRange.End-byteRange.Start+1 { + h.logAndSendError(w, "byte range length is longer than allowed", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err)) + return + } + + if !overwrite && contentLen != byteRange.End-byteRange.Start+1 { + h.logAndSendError(w, "content-length must be equal to byte range length", reqInfo, errors.GetAPIError(errors.ErrInvalidRangeLength), zap.Error(err)) + return + } + + if byteRange.Start > srcSize { + h.logAndSendError(w, "start byte is greater than object size", reqInfo, errors.GetAPIError(errors.ErrRangeOutOfBounds)) + return + } + + params := &layer.PatchObjectParams{ + Object: extendedSrcObjInfo, + BktInfo: bktInfo, + NewBytes: body, + Range: byteRange, + VersioningEnabled: settings.VersioningEnabled(), + Overwrite: overwrite, + } + + params.CopiesNumbers, err = h.pickCopiesNumbers(nil, reqInfo.Namespace, bktInfo.LocationConstraint) + if err != nil { + h.logAndSendError(w, "invalid copies number", reqInfo, err) + return + } + + extendedObjInfo, err := h.obj.PatchObject(ctx, params) + if err != nil { + if isErrObjectLocked(err) { + h.logAndSendError(w, "object is locked", reqInfo, errors.GetAPIError(errors.ErrAccessDenied)) + } else { + h.logAndSendError(w, "could not patch object", reqInfo, err) + } + return + } + + if settings.VersioningEnabled() { + w.Header().Set(api.AmzVersionID, extendedObjInfo.ObjectInfo.VersionID()) + } + + w.Header().Set(api.ETag, data.Quote(extendedObjInfo.ObjectInfo.ETag(h.cfg.MD5Enabled()))) + + if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil { + h.logAndSendError(w, "write response", reqInfo, err) + } +} + +func parsePutRange(rangeStr string, objSize, contentLen uint64) (*layer.RangeParams, bool, error) { + const prefix = "bytes=" + var overwrite bool + + if rangeStr == "" { + return nil, overwrite, fmt.Errorf("empty range") + } + + if !strings.HasPrefix(rangeStr, prefix) { + return nil, overwrite, fmt.Errorf("unknown unit in range header") + } + + rangeStr = strings.TrimPrefix(rangeStr, prefix) + i := strings.LastIndex(rangeStr, "-") + if i < 0 { + return nil, overwrite, fmt.Errorf("invalid range: %s", rangeStr) + } + + startStr, endStr := rangeStr[:i], rangeStr[i+1:] + start, err := strconv.ParseInt(startStr, 10, 64) + if err != nil { + return nil, overwrite, fmt.Errorf("invalid start byte: %s", startStr) + } + + if start == -1 && len(endStr) == 0 { + return &layer.RangeParams{ + Start: objSize, + End: objSize + contentLen - 1, + }, overwrite, nil + } + + if start < 0 { + return nil, overwrite, fmt.Errorf("invalid range: %s", rangeStr) + } + + end := uint64(start) + contentLen - 1 + if contentLen == 0 { + end = objSize - 1 + } + + if len(endStr) > 0 { + end, err = strconv.ParseUint(endStr, 10, 64) + if err != nil { + return nil, overwrite, fmt.Errorf("invalid end byte: %s", endStr) + } + } else { + overwrite = true + } + + if uint64(start) > end { + return nil, overwrite, fmt.Errorf("start byte is greater than end byte") + } + + return &layer.RangeParams{ + Start: uint64(start), + End: end, + }, overwrite, nil +} + func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) { if !api.IsSignedStreamingV4(r) { return r.Body, nil diff --git a/api/handler/put_test.go b/api/handler/put_test.go index 570dfb6..df6237c 100644 --- a/api/handler/put_test.go +++ b/api/handler/put_test.go @@ -4,9 +4,11 @@ import ( "bytes" "context" "crypto/md5" + "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" + "fmt" "io" "mime/multipart" "net/http" @@ -34,6 +36,430 @@ const ( awsChunkedRequestExampleContentLength = 66824 ) +func TestPutObjectWithRange(t *testing.T) { + hc := prepareHandlerContext(t) + + bktName, objName, partSize := "bucket-put-range", "object-put-range", 5*1024*1024 + putObjectWithHeadersAssertS3Error(hc, bktName, objName, "content", map[string]string{api.Range: "bytes=-1-"}, s3errors.ErrNoSuchBucket) + + createBucket(hc, bktName) + putObjectWithHeadersAssertS3Error(hc, bktName, objName, "content", map[string]string{api.Range: "bytes=-1-"}, s3errors.ErrNoSuchKey) + + for _, tt := range []struct { + name string + createObj func() + rng string + rngContent []byte + expected func([]byte, []byte) []byte + code s3errors.ErrorCode + }{ + { + name: "append empty regular object", + createObj: func() { + putObjectContent(hc, bktName, objName, "") + }, + rng: "bytes=-1-", + rngContent: []byte("content"), + expected: func(_, rngContent []byte) []byte { + return rngContent + }, + }, + { + name: "append regular object", + createObj: func() { + putObjectContent(hc, bktName, objName, "object") + }, + rng: "bytes=-1-", + rngContent: []byte("content"), + expected: func(objContent, rngContent []byte) []byte { + return bytes.Join([][]byte{objContent, rngContent}, []byte{}) + }, + }, + { + name: "append empty multipart object", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, 0) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1}) + }, + rng: "bytes=-1-", + rngContent: []byte("content"), + expected: func(_, rngContent []byte) []byte { + return rngContent + }, + }, + { + name: "append multipart object", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=-1-", + rngContent: []byte("content"), + expected: func(objContent, rngContent []byte) []byte { + return bytes.Join([][]byte{objContent, rngContent}, []byte{}) + }, + }, + { + name: "update regular object", + createObj: func() { + putObjectContent(hc, bktName, objName, "object old content") + }, + rng: "bytes=7-9", + rngContent: []byte("new"), + expected: func(objContent, rngContent []byte) []byte { + return bytes.Join([][]byte{objContent[:7], rngContent, objContent[10:]}, []byte{}) + }, + }, + { + name: "update multipart object", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize/2) + "-" + strconv.Itoa(partSize*5/2-1), + rngContent: func() []byte { + rangeContent := make([]byte, partSize*2) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize/2], rangeContent, objContent[partSize*5/2:]}, []byte{}) + }, + }, + { + name: "overwrite regular object, increase size", + createObj: func() { + putObjectContent(hc, bktName, objName, "object old") + }, + rng: "bytes=7-", + rngContent: []byte("new content"), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:7], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite regular object, decrease size", + createObj: func() { + putObjectContent(hc, bktName, objName, "object old content") + }, + rng: "bytes=7-", + rngContent: []byte("new"), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:7], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite multipart object, increase size", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*3/2) + "-", + rngContent: func() []byte { + rangeContent := make([]byte, partSize*2) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*3/2], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite multipart object, reduce number of parts", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize/2-1) + "-", + rngContent: func() []byte { + rangeContent := make([]byte, partSize) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize/2-1], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite multipart object, decrease size of last part", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*3/2) + "-", + rngContent: func() []byte { + rangeContent := make([]byte, partSize) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*3/2], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite last part, increase size", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*5/2) + "-", + rngContent: func() []byte { + rangeContent := make([]byte, partSize) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*5/2], rangeContent}, []byte{}) + }, + }, + { + name: "overwrite last part, decrease size", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*9/4) + "-", + rngContent: func() []byte { + rangeContent := make([]byte, partSize/2) + _, err := rand.Read(rangeContent) + require.NoError(t, err) + return rangeContent + }(), + expected: func(objContent, rangeContent []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*9/4], rangeContent}, []byte{}) + }, + }, + { + name: "regular object, empty range content", + createObj: func() { + putObjectContent(hc, bktName, objName, "object old") + }, + rng: "bytes=6-", + rngContent: []byte{}, + expected: func(objContent, _ []byte) []byte { + return objContent[:6] + }, + }, + { + name: "multipart object, empty range content, decrease size of last part", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*9/4) + "-", + rngContent: []byte{}, + expected: func(objContent, _ []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*9/4]}, []byte{}) + }, + }, + { + name: "multipart object, empty range content, decrease number of parts", + createObj: func() { + multipartInfo := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 2, partSize) + etag3, _ := uploadPart(hc, bktName, objName, multipartInfo.UploadID, 3, partSize) + completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, []string{etag1, etag2, etag3}) + }, + rng: "bytes=" + strconv.Itoa(partSize*5/4) + "-", + rngContent: func() []byte { + return []byte{} + }(), + expected: func(objContent, _ []byte) []byte { + return bytes.Join([][]byte{objContent[:partSize*5/4]}, []byte{}) + }, + }, + { + name: "invalid range length", + createObj: func() { + putObjectContent(hc, bktName, objName, "object") + }, + rng: "bytes=4-7", + rngContent: []byte("content"), + code: s3errors.ErrInvalidRangeLength, + }, + { + name: "invalid start byte", + createObj: func() { + putObjectContent(hc, bktName, objName, "object") + }, + rng: "bytes=7-", + rngContent: []byte("content"), + code: s3errors.ErrRangeOutOfBounds, + }, + { + name: "invalid range", + createObj: func() { + putObjectContent(hc, bktName, objName, "object") + }, + rng: "bytes=12-6", + rngContent: []byte("content"), + code: s3errors.ErrInvalidRange, + }, + { + name: "encrypted object", + createObj: func() { + putEncryptedObject(t, hc, bktName, objName, "object") + }, + rng: "bytes=-1-", + rngContent: []byte("content"), + code: s3errors.ErrInternalError, + }, + { + name: "range is too long", + createObj: func() { + putObjectContent(hc, bktName, objName, "object") + }, + rng: "bytes=0-5368709120", + rngContent: []byte("content"), + code: s3errors.ErrInvalidRange, + }, + } { + t.Run(tt.name, func(t *testing.T) { + tt.createObj() + + if tt.code == 0 { + objContent, _ := getObject(hc, bktName, objName) + putObjectContentWithHeaders(hc, bktName, objName, string(tt.rngContent), map[string]string{api.Range: tt.rng}) + patchedObj, _ := getObject(hc, bktName, objName) + equalDataSlices(t, tt.expected(objContent, tt.rngContent), patchedObj) + } else { + putObjectWithHeadersAssertS3Error(hc, bktName, objName, string(tt.rngContent), map[string]string{api.Range: tt.rng}, tt.code) + } + }) + } +} + +func TestParsePutRange(t *testing.T) { + for _, tt := range []struct { + rng string + objSize uint64 + contentLen uint64 + expected *layer.RangeParams + overwrite bool + err bool + }{ + { + rng: "bytes=-1-", + objSize: 10, + contentLen: 10, + expected: &layer.RangeParams{ + Start: 10, + End: 19, + }, + }, + { + rng: "bytes=4-7", + expected: &layer.RangeParams{ + Start: 4, + End: 7, + }, + }, + { + rng: "bytes=4-", + contentLen: 7, + expected: &layer.RangeParams{ + Start: 4, + End: 10, + }, + overwrite: true, + }, + { + rng: "bytes=7-", + objSize: 10, + contentLen: 0, + expected: &layer.RangeParams{ + Start: 7, + End: 9, + }, + overwrite: true, + }, + { + rng: "", + err: true, + }, + { + rng: "4-7", + err: true, + }, + { + rng: "bytes=7-4", + err: true, + }, + { + rng: "bytes=-10-", + err: true, + }, + { + rng: "bytes=-1-10", + err: true, + }, + { + rng: "bytes=1--10", + err: true, + }, + { + rng: "bytes=10", + err: true, + }, + { + rng: "bytes=10-a", + err: true, + }, + { + rng: "bytes=a-10", + err: true, + }, + { + rng: "bytes=10-", + objSize: 10, + err: true, + }, + } { + t.Run(fmt.Sprintf("case: %s", tt.rng), func(t *testing.T) { + rng, overwrite, err := parsePutRange(tt.rng, tt.objSize, tt.contentLen) + if tt.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected.Start, rng.Start) + require.Equal(t, tt.expected.End, rng.End) + require.Equal(t, tt.overwrite, overwrite) + } + }) + } +} + func TestCheckBucketName(t *testing.T) { for _, tc := range []struct { name string diff --git a/api/headers.go b/api/headers.go index 540d3cc..e1a6b18 100644 --- a/api/headers.go +++ b/api/headers.go @@ -40,6 +40,7 @@ const ( IfUnmodifiedSince = "If-Unmodified-Since" IfMatch = "If-Match" IfNoneMatch = "If-None-Match" + Range = "Range" AmzContentSha256 = "X-Amz-Content-Sha256" AmzCopyIfModifiedSince = "X-Amz-Copy-Source-If-Modified-Since" diff --git a/api/layer/patch.go b/api/layer/patch.go index 44d4e4f..89706ac 100644 --- a/api/layer/patch.go +++ b/api/layer/patch.go @@ -21,6 +21,7 @@ type PatchObjectParams struct { Range *RangeParams VersioningEnabled bool CopiesNumbers []uint32 + Overwrite bool } func (n *Layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.ExtendedObjectInfo, error) { @@ -42,6 +43,10 @@ func (n *Layer) PatchObject(ctx context.Context, p *PatchObjectParams) (*data.Ex } n.prepareAuthParameters(ctx, &prmPatch.PrmAuth, p.BktInfo.Owner) + if p.Overwrite { + prmPatch.Length = p.Object.ObjectInfo.Size - p.Range.Start + } + createdObj, err := n.patchObject(ctx, prmPatch) if err != nil { return nil, fmt.Errorf("patch object: %w", err) @@ -115,9 +120,13 @@ func (n *Layer) patchMultipartObject(ctx context.Context, p *PatchObjectParams) } n.prepareAuthParameters(ctx, &prmPatch.PrmAuth, p.BktInfo.Owner) - off, ln := p.Range.Start, p.Range.End-p.Range.Start+1 - var multipartObjectSize uint64 - for i, part := range parts { + var ( + multipartObjectSize uint64 + i int + part *data.PartInfo + off, ln = p.Range.Start, p.Range.End - p.Range.Start + 1 + ) + for i, part = range parts { if off > part.Size || (off == part.Size && i != len(parts)-1) || ln == 0 { multipartObjectSize += part.Size if ln != 0 { @@ -133,21 +142,27 @@ func (n *Layer) patchMultipartObject(ctx context.Context, p *PatchObjectParams) } parts[i].OID = createdObj.ID - parts[i].Size = createdObj.Size parts[i].MD5 = "" parts[i].ETag = hex.EncodeToString(createdObj.HashSum) multipartObjectSize += createdObj.Size + + if createdObj.Size < parts[i].Size { + parts[i].Size = createdObj.Size + break + } + + parts[i].Size = createdObj.Size } - return n.updateCombinedObject(ctx, parts, multipartObjectSize, p) + return n.updateCombinedObject(ctx, parts[:i+1], multipartObjectSize, p) } // Returns patched part info, updated offset and length. func (n *Layer) patchPart(ctx context.Context, part *data.PartInfo, p *PatchObjectParams, prmPatch *PrmObjectPatch, off, ln uint64, lastPart bool) (*data.CreatedObjectInfo, uint64, uint64, error) { - if off == 0 && ln >= part.Size { + if off == 0 && (ln >= part.Size || p.Overwrite) { curLen := part.Size - if lastPart { + if lastPart || (p.Overwrite && ln > part.Size) { curLen = ln } prm := PrmObjectCreate{ @@ -162,13 +177,12 @@ func (n *Layer) patchPart(ctx context.Context, part *data.PartInfo, p *PatchObje return nil, 0, 0, fmt.Errorf("put new part object '%s': %w", part.OID.EncodeToString(), err) } - ln -= curLen - + ln -= min(curLen, createdObj.Size) return createdObj, off, ln, err } curLen := ln - if off+curLen > part.Size && !lastPart { + if (off+curLen > part.Size && !lastPart) || (p.Overwrite && off+curLen < part.Size) { curLen = part.Size - off } prmPatch.Object = part.OID @@ -183,7 +197,7 @@ func (n *Layer) patchPart(ctx context.Context, part *data.PartInfo, p *PatchObje return nil, 0, 0, fmt.Errorf("patch part object '%s': %w", part.OID.EncodeToString(), err) } - ln -= curLen + ln -= min(curLen, createdObj.Size-off) off = 0 return createdObj, off, ln, nil