diff --git a/api/errors/errors.go b/api/errors/errors.go index 5d2709b..5c807b1 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -47,6 +47,7 @@ const ( ErrInvalidRequestBody ErrInvalidCopySource ErrInvalidMetadataDirective + ErrInvalidTaggingDirective ErrInvalidCopyDest ErrInvalidPolicyDocument ErrInvalidObjectState @@ -294,6 +295,12 @@ var errorCodes = errorCodeMap{ Description: "Unknown metadata directive.", HTTPStatusCode: http.StatusBadRequest, }, + ErrInvalidTaggingDirective: { + ErrCode: ErrInvalidTaggingDirective, + Code: "InvalidArgument", + Description: "Unknown tagging directive.", + HTTPStatusCode: http.StatusBadRequest, + }, ErrInvalidStorageClass: { ErrCode: ErrInvalidStorageClass, Code: "InvalidStorageClass", diff --git a/api/handler/copy.go b/api/handler/copy.go index 14bc73a..69ee957 100644 --- a/api/handler/copy.go +++ b/api/handler/copy.go @@ -17,9 +17,13 @@ import ( type copyObjectArgs struct { Conditional *conditionalArgs MetadataDirective string + TaggingDirective string } -const replaceMetadataDirective = "REPLACE" +const ( + replaceDirective = "REPLACE" + copyDirective = "COPY" +) // path2BucketObject returns a bucket and an object. func path2BucketObject(path string) (bucket, prefix string) { @@ -33,8 +37,10 @@ func path2BucketObject(path string) (bucket, prefix string) { func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) { var ( + err error versionID string metadata map[string]string + tagSet map[string]string sessionTokenEACL *session.Container reqInfo = api.GetReqInfo(r.Context()) @@ -42,7 +48,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) { containsACL = containsACLHeaders(r) ) - src := r.Header.Get("X-Amz-Copy-Source") + src := r.Header.Get(api.AmzCopySource) // Check https://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectVersioning.html // Regardless of whether you have enabled versioning, each object in your bucket // has a version ID. If you have not enabled versioning, Amazon S3 sets the value @@ -55,23 +61,11 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) { srcBucket, srcObject := path2BucketObject(src) - args, err := parseCopyObjectArgs(r.Header) - if err != nil { - h.logAndSendError(w, "could not parse request params", reqInfo, err) - return - } p := &layer.HeadObjectParams{ Object: srcObject, VersionID: versionID, } - if args.MetadataDirective == replaceMetadataDirective { - metadata = parseMetadata(r) - } else if srcBucket == reqInfo.BucketName && srcObject == reqInfo.ObjectName { - h.logAndSendError(w, "could not copy to itself", reqInfo, errors.GetAPIError(errors.ErrInvalidRequest)) - return - } - if p.BktInfo, err = h.getBucketAndCheckOwner(r, srcBucket, api.AmzSourceExpectedBucketOwner); err != nil { h.logAndSendError(w, "couldn't get source bucket", reqInfo, err) return @@ -96,6 +90,36 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) { return } + args, err := parseCopyObjectArgs(r.Header) + if err != nil { + h.logAndSendError(w, "could not parse request params", reqInfo, err) + return + } + + if args.MetadataDirective == replaceDirective { + metadata = parseMetadata(r) + } + + if args.TaggingDirective == replaceDirective { + tagSet, err = parseTaggingHeader(r.Header) + if err != nil { + h.logAndSendError(w, "could not parse tagging header", reqInfo, err) + return + } + } else { + objVersion := &layer.ObjectVersion{ + BktInfo: p.BktInfo, + ObjectName: srcObject, + VersionID: objInfo.VersionID(), + } + + _, tagSet, err = h.obj.GetObjectTagging(r.Context(), objVersion) + if err != nil { + h.logAndSendError(w, "could not get object tagging", reqInfo, err) + return + } + } + encryptionParams, err := h.formEncryptionParams(r.Header) if err != nil { h.logAndSendError(w, "invalid sse headers", reqInfo, err) @@ -178,6 +202,18 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) { } } + if tagSet != nil { + t := &layer.ObjectVersion{ + BktInfo: dstBktInfo, + ObjectName: reqInfo.ObjectName, + VersionID: objInfo.VersionID(), + } + if _, err = h.obj.PutObjectTagging(r.Context(), t, tagSet); err != nil { + h.logAndSendError(w, "could not upload object tagging", reqInfo, err) + return + } + } + h.log.Info("object is copied", zap.String("bucket", objInfo.Bucket), zap.String("object", objInfo.Name), @@ -213,7 +249,21 @@ func parseCopyObjectArgs(headers http.Header) (*copyObjectArgs, error) { } copyArgs := ©ObjectArgs{Conditional: args} + copyArgs.MetadataDirective = headers.Get(api.AmzMetadataDirective) + if !isValidDirective(copyArgs.MetadataDirective) { + return nil, errors.GetAPIError(errors.ErrInvalidMetadataDirective) + } + + copyArgs.TaggingDirective = headers.Get(api.AmzTaggingDirective) + if !isValidDirective(copyArgs.TaggingDirective) { + return nil, errors.GetAPIError(errors.ErrInvalidTaggingDirective) + } return copyArgs, nil } + +func isValidDirective(directive string) bool { + return len(directive) == 0 || + directive == replaceDirective || directive == copyDirective +} diff --git a/api/handler/copy_test.go b/api/handler/copy_test.go new file mode 100644 index 0000000..2a13ab0 --- /dev/null +++ b/api/handler/copy_test.go @@ -0,0 +1,95 @@ +package handler + +import ( + "encoding/xml" + "net/http" + "net/url" + "testing" + + "github.com/nspcc-dev/neofs-s3-gw/api" + "github.com/stretchr/testify/require" +) + +type CopyMeta struct { + TaggingDirective string + Tags map[string]string + MetadataDirective string + Metadata map[string]string +} + +func TestCopyWithTaggingDirective(t *testing.T) { + tc := prepareHandlerContext(t) + + bktName, objName := "bucket-for-copy", "object-from-copy" + objToCopy, objToCopy2 := "object-to-copy", "object-to-copy-2" + createBucketAndObject(t, tc, bktName, objName) + + putObjectTagging(t, tc, bktName, objName, map[string]string{"key": "val"}) + + copyMeta := CopyMeta{ + Tags: map[string]string{"key2": "val"}, + } + copyObject(t, tc, bktName, objName, objToCopy, copyMeta) + tagging := getObjectTagging(t, tc, bktName, objToCopy, emptyVersion) + require.Len(t, tagging.TagSet, 1) + require.Equal(t, "key", tagging.TagSet[0].Key) + require.Equal(t, "val", tagging.TagSet[0].Value) + + copyMeta.TaggingDirective = replaceDirective + copyObject(t, tc, bktName, objName, objToCopy2, copyMeta) + tagging = getObjectTagging(t, tc, bktName, objToCopy2, emptyVersion) + require.Len(t, tagging.TagSet, 1) + require.Equal(t, "key2", tagging.TagSet[0].Key) + require.Equal(t, "val", tagging.TagSet[0].Value) +} + +func copyObject(t *testing.T, tc *handlerContext, bktName, fromObject, toObject string, copyMeta CopyMeta) { + w, r := prepareTestRequest(t, bktName, toObject, nil) + r.Header.Set(api.AmzCopySource, bktName+"/"+fromObject) + + r.Header.Set(api.AmzMetadataDirective, copyMeta.MetadataDirective) + for key, val := range copyMeta.Metadata { + r.Header.Set(api.MetadataPrefix+key, val) + } + + r.Header.Set(api.AmzTaggingDirective, copyMeta.TaggingDirective) + tagsQuery := make(url.Values) + for key, val := range copyMeta.Tags { + tagsQuery.Set(key, val) + } + r.Header.Set(api.AmzTagging, tagsQuery.Encode()) + + tc.Handler().CopyObjectHandler(w, r) + assertStatus(t, w, http.StatusOK) +} + +func putObjectTagging(t *testing.T, tc *handlerContext, bktName, objName string, tags map[string]string) { + body := &Tagging{ + TagSet: make([]Tag, 0, len(tags)), + } + + for key, val := range tags { + body.TagSet = append(body.TagSet, Tag{ + Key: key, + Value: val, + }) + } + + w, r := prepareTestRequest(t, bktName, objName, body) + tc.Handler().PutObjectTaggingHandler(w, r) + assertStatus(t, w, http.StatusOK) +} + +func getObjectTagging(t *testing.T, tc *handlerContext, bktName, objName, version string) *Tagging { + query := make(url.Values) + query.Add(api.QueryVersionID, version) + + w, r := prepareTestFullRequest(t, bktName, objName, query, nil) + tc.Handler().GetObjectTaggingHandler(w, r) + assertStatus(t, w, http.StatusOK) + + tagging := &Tagging{} + err := xml.NewDecoder(w.Result().Body).Decode(tagging) + require.NoError(t, err) + return tagging +} diff --git a/api/headers.go b/api/headers.go index 1bc585b..b4587e8 100644 --- a/api/headers.go +++ b/api/headers.go @@ -5,6 +5,7 @@ const ( MetadataPrefix = "X-Amz-Meta-" NeoFSSystemMetadataPrefix = "S3-" AmzMetadataDirective = "X-Amz-Metadata-Directive" + AmzTaggingDirective = "X-Amz-Tagging-Directive" AmzVersionID = "X-Amz-Version-Id" AmzTaggingCount = "X-Amz-Tagging-Count" AmzTagging = "X-Amz-Tagging" diff --git a/api/layer/tree_mock.go b/api/layer/tree_mock.go index cb412ae..3392c44 100644 --- a/api/layer/tree_mock.go +++ b/api/layer/tree_mock.go @@ -16,6 +16,7 @@ type TreeServiceMock struct { versions map[string]map[string][]*data.NodeVersion system map[string]map[string]*data.BaseNodeVersion locks map[string]map[uint64]*data.LockInfo + tags map[string]map[uint64]map[string]string multiparts map[string]map[string][]*data.MultipartInfo parts map[string]map[int]*data.PartInfo } @@ -26,19 +27,37 @@ func (t *TreeServiceMock) GetObjectTaggingAndLock(ctx context.Context, cnrID cid return nil, lock, err } -func (t *TreeServiceMock) GetObjectTagging(ctx context.Context, cnrID cid.ID, objVersion *data.NodeVersion) (map[string]string, error) { - // TODO implement me - panic("implement me") +func (t *TreeServiceMock) GetObjectTagging(_ context.Context, cnrID cid.ID, nodeVersion *data.NodeVersion) (map[string]string, error) { + cnrTagsMap, ok := t.tags[cnrID.EncodeToString()] + if !ok { + return nil, nil + } + + return cnrTagsMap[nodeVersion.ID], nil } -func (t *TreeServiceMock) PutObjectTagging(ctx context.Context, cnrID cid.ID, objVersion *data.NodeVersion, tagSet map[string]string) error { - // TODO implement me - panic("implement me") +func (t *TreeServiceMock) PutObjectTagging(_ context.Context, cnrID cid.ID, nodeVersion *data.NodeVersion, tagSet map[string]string) error { + cnrTagsMap, ok := t.tags[cnrID.EncodeToString()] + if !ok { + t.tags[cnrID.EncodeToString()] = map[uint64]map[string]string{ + nodeVersion.ID: tagSet, + } + return nil + } + + cnrTagsMap[nodeVersion.ID] = tagSet + + return nil } -func (t *TreeServiceMock) DeleteObjectTagging(ctx context.Context, cnrID cid.ID, objVersion *data.NodeVersion) error { - // TODO implement me - panic("implement me") +func (t *TreeServiceMock) DeleteObjectTagging(_ context.Context, cnrID cid.ID, objVersion *data.NodeVersion) error { + cnrTagsMap, ok := t.tags[cnrID.EncodeToString()] + if !ok { + return nil + } + + delete(cnrTagsMap, objVersion.ID) + return nil } func (t *TreeServiceMock) GetBucketTagging(ctx context.Context, cnrID cid.ID) (map[string]string, error) { @@ -62,6 +81,7 @@ func NewTreeService() *TreeServiceMock { versions: make(map[string]map[string][]*data.NodeVersion), system: make(map[string]map[string]*data.BaseNodeVersion), locks: make(map[string]map[uint64]*data.LockInfo), + tags: make(map[string]map[uint64]map[string]string), multiparts: make(map[string]map[string][]*data.MultipartInfo), parts: make(map[string]map[int]*data.PartInfo), }