From 9df86954630b09efb78d4e87be6775bb1ff012f2 Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Wed, 21 Jun 2023 17:16:40 +0300 Subject: [PATCH] [#143] Fix transformToS3Error function Unwrap error before checking for s3 error Signed-off-by: Denis Kirillov --- CHANGELOG.md | 1 + api/handler/encryption_test.go | 9 +++- api/handler/multipart_upload_test.go | 15 +++++++ api/handler/util.go | 30 +++++++------ api/handler/util_test.go | 64 ++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 15 deletions(-) create mode 100644 api/handler/util_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e8a93b9..c32b208 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This document outlines major changes between releases. - Don't create unnecessary delete-markers (#83) - Handle negative `Content-Length` on put (#125) - Use `DisableURIPathEscaping` to presign urls (#125) +- Use specific s3 errors instead of `InternalError` where possible (#143) ### Added - Implement chunk uploading (#106) diff --git a/api/handler/encryption_test.go b/api/handler/encryption_test.go index 3394b32..5684ea3 100644 --- a/api/handler/encryption_test.go +++ b/api/handler/encryption_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "net/url" "strconv" "strings" @@ -190,6 +191,11 @@ func createMultipartUploadBase(hc *handlerContext, bktName, objName string, encr } func completeMultipartUpload(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) { + w := completeMultipartUploadBase(hc, bktName, objName, uploadID, partsETags) + assertStatus(hc.t, w, http.StatusOK) +} + +func completeMultipartUploadBase(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) *httptest.ResponseRecorder { query := make(url.Values) query.Set(uploadIDQuery, uploadID) complete := &CompleteMultipartUpload{ @@ -204,7 +210,8 @@ func completeMultipartUpload(hc *handlerContext, bktName, objName, uploadID stri w, r := prepareTestFullRequest(hc, bktName, objName, query, complete) hc.Handler().CompleteMultipartUploadHandler(w, r) - assertStatus(hc.t, w, http.StatusOK) + + return w } func uploadPartEncrypted(hc *handlerContext, bktName, objName, uploadID string, num, size int) (string, []byte) { diff --git a/api/handler/multipart_upload_test.go b/api/handler/multipart_upload_test.go index 6ebf691..a43c468 100644 --- a/api/handler/multipart_upload_test.go +++ b/api/handler/multipart_upload_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + s3Errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" "github.com/stretchr/testify/require" ) @@ -46,3 +47,17 @@ func TestPeriodicWriter(t *testing.T) { }) }) } + +func TestMultipartUploadInvalidPart(t *testing.T) { + hc := prepareHandlerContext(t) + + bktName, objName := "bucket-to-upload-part", "object-multipart" + createTestBucket(hc, bktName) + partSize := 8 // less than min part size + + multipartUpload := createMultipartUpload(hc, bktName, objName, map[string]string{}) + etag1, _ := uploadPart(hc, bktName, objName, multipartUpload.UploadID, 1, partSize) + etag2, _ := uploadPart(hc, bktName, objName, multipartUpload.UploadID, 2, partSize) + w := completeMultipartUploadBase(hc, bktName, objName, multipartUpload.UploadID, []string{etag1, etag2}) + assertS3Error(hc.t, w, s3Errors.GetAPIError(s3Errors.ErrEntityTooSmall)) +} diff --git a/api/handler/util.go b/api/handler/util.go index 18d9131..e3a21a9 100644 --- a/api/handler/util.go +++ b/api/handler/util.go @@ -2,15 +2,16 @@ package handler import ( "context" - errorsStd "errors" + "errors" "net/http" "strconv" "strings" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data" - "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" + s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer" + frosterrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session" "go.uber.org/zap" ) @@ -51,20 +52,21 @@ func (h *handler) logAndSendErrorNoHeader(w http.ResponseWriter, logText string, } func transformToS3Error(err error) error { - if _, ok := err.(errors.Error); ok { + err = frosterrors.UnwrapErr(err) // this wouldn't work with errors.Join + if _, ok := err.(s3errors.Error); ok { return err } - if errorsStd.Is(err, layer.ErrAccessDenied) || - errorsStd.Is(err, layer.ErrNodeAccessDenied) { - return errors.GetAPIError(errors.ErrAccessDenied) + if errors.Is(err, layer.ErrAccessDenied) || + errors.Is(err, layer.ErrNodeAccessDenied) { + return s3errors.GetAPIError(s3errors.ErrAccessDenied) } - if errorsStd.Is(err, layer.ErrGatewayTimeout) { - return errors.GetAPIError(errors.ErrGatewayTimeout) + if errors.Is(err, layer.ErrGatewayTimeout) { + return s3errors.GetAPIError(s3errors.ErrGatewayTimeout) } - return errors.GetAPIError(errors.ErrInternalError) + return s3errors.GetAPIError(s3errors.ErrInternalError) } func (h *handler) ResolveBucket(ctx context.Context, bucket string) (*data.BucketInfo, error) { @@ -99,26 +101,26 @@ func parseRange(s string) (*layer.RangeParams, error) { prefix := "bytes=" if !strings.HasPrefix(s, prefix) { - return nil, errors.GetAPIError(errors.ErrInvalidRange) + return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange) } s = strings.TrimPrefix(s, prefix) valuesStr := strings.Split(s, "-") if len(valuesStr) != 2 { - return nil, errors.GetAPIError(errors.ErrInvalidRange) + return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange) } values := make([]uint64, 0, len(valuesStr)) for _, v := range valuesStr { num, err := strconv.ParseUint(v, 10, 64) if err != nil { - return nil, errors.GetAPIError(errors.ErrInvalidRange) + return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange) } values = append(values, num) } if values[0] > values[1] { - return nil, errors.GetAPIError(errors.ErrInvalidRange) + return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange) } return &layer.RangeParams{ @@ -134,7 +136,7 @@ func getSessionTokenSetEACL(ctx context.Context) (*session.Container, error) { } sessionToken := boxData.Gate.SessionTokenForSetEACL() if sessionToken == nil { - return nil, errors.GetAPIError(errors.ErrAccessDenied) + return nil, s3errors.GetAPIError(s3errors.ErrAccessDenied) } return sessionToken, nil diff --git a/api/handler/util_test.go b/api/handler/util_test.go new file mode 100644 index 0000000..587f92c --- /dev/null +++ b/api/handler/util_test.go @@ -0,0 +1,64 @@ +package handler + +import ( + "errors" + "fmt" + "testing" + + s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer" + "github.com/stretchr/testify/require" +) + +func TestTransformS3Errors(t *testing.T) { + for _, tc := range []struct { + name string + err error + expected s3errors.ErrorCode + }{ + { + name: "simple std error to internal error", + err: errors.New("some error"), + expected: s3errors.ErrInternalError, + }, + { + name: "layer access denied error to s3 access denied error", + err: layer.ErrAccessDenied, + expected: s3errors.ErrAccessDenied, + }, + { + name: "wrapped layer access denied error to s3 access denied error", + err: fmt.Errorf("wrap: %w", layer.ErrAccessDenied), + expected: s3errors.ErrAccessDenied, + }, + { + name: "layer node access denied error to s3 access denied error", + err: layer.ErrNodeAccessDenied, + expected: s3errors.ErrAccessDenied, + }, + { + name: "layer gateway timeout error to s3 gateway timeout error", + err: layer.ErrGatewayTimeout, + expected: s3errors.ErrGatewayTimeout, + }, + { + name: "s3 error to s3 error", + err: s3errors.GetAPIError(s3errors.ErrInvalidPart), + expected: s3errors.ErrInvalidPart, + }, + { + name: "wrapped s3 error to s3 error", + err: fmt.Errorf("wrap: %w", s3errors.GetAPIError(s3errors.ErrInvalidPart)), + expected: s3errors.ErrInvalidPart, + }, + } { + t.Run(tc.name, func(t *testing.T) { + err := transformToS3Error(tc.err) + s3err, ok := err.(s3errors.Error) + require.True(t, ok, "error must be s3 error") + require.Equalf(t, tc.expected, s3err.ErrCode, + "expected: '%s', got: '%s'", + s3errors.GetAPIError(tc.expected).Code, s3errors.GetAPIError(s3err.ErrCode).Code) + }) + } +}