diff --git a/api/auth/center.go b/api/auth/center.go index dafa9c1..6b7032e 100644 --- a/api/auth/center.go +++ b/api/auth/center.go @@ -285,7 +285,7 @@ func (c *Center) checkFormData(r *http.Request) (*middleware.Box, error) { secret := box.Gate.SecretKey service, region := submatches["service"], submatches["region"] - signature := signStr(secret, service, region, signatureDateTime, policy) + signature := SignStr(secret, service, region, signatureDateTime, policy) reqSignature := MultipartFormValue(r, "x-amz-signature") if signature != reqSignature { return nil, fmt.Errorf("%w: %s != %s", apiErrors.GetAPIError(apiErrors.ErrSignatureDoesNotMatch), @@ -359,7 +359,7 @@ func (c *Center) checkSign(authHeader *AuthHeader, box *accessbox.Box, request * return nil } -func signStr(secret, service, region string, t time.Time, strToSign string) string { +func SignStr(secret, service, region string, t time.Time, strToSign string) string { creds := deriveKey(secret, service, region, t) signature := hmacSHA256(creds, []byte(strToSign)) return hex.EncodeToString(signature) diff --git a/api/auth/center_test.go b/api/auth/center_test.go index d70ad6d..edcaca1 100644 --- a/api/auth/center_test.go +++ b/api/auth/center_test.go @@ -115,7 +115,7 @@ func TestSignature(t *testing.T) { panic(err) } - signature := signStr(secret, "s3", "us-east-1", signTime, strToSign) + signature := SignStr(secret, "s3", "us-east-1", signTime, strToSign) require.Equal(t, "dfbe886241d9e369cf4b329ca0f15eb27306c97aa1022cc0bb5a914c4ef87634", signature) } @@ -492,7 +492,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST valid", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) }(), @@ -501,7 +501,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST valid with custom field name", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, "files") }(), @@ -510,7 +510,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST valid with field name with a capital letter", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, "File") }(), @@ -530,7 +530,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST invalid signature date time", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, invalidValue, sign, defaultFieldName) }(), @@ -539,7 +539,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { { name: "HTTP POST invalid creds", request: func() *http.Request { - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, invalidValue, timeToSignStr, sign, defaultFieldName) }(), @@ -550,7 +550,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST missing policy", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, "", creds, timeToSignStr, sign, defaultFieldName) }(), @@ -560,7 +560,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST invalid accessKeyId", request: func() *http.Request { creds := getCredsStr(invalidValue, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) }(), @@ -570,7 +570,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST invalid accessKeyId - a non-existent box", request: func() *http.Request { creds := getCredsStr(invalidAccessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + sign := SignStr(secret.SecretKey, service, region, timeToSign, policyBase64) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) }(), @@ -580,7 +580,7 @@ func TestHTTPPostAuthenticate(t *testing.T) { name: "HTTP POST invalid signature", request: func() *http.Request { creds := getCredsStr(accessKeyID, timeToSignStr, region, service) - sign := signStr(secret.SecretKey, service, region, timeToSign, invalidValue) + sign := SignStr(secret.SecretKey, service, region, timeToSign, invalidValue) return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) }(), @@ -617,7 +617,7 @@ func getRequestWithMultipartForm(t *testing.T, policy, creds, date, sign, fieldN writer := multipart.NewWriter(body) defer writer.Close() - err := writer.WriteField("Policy", policy) + err := writer.WriteField("policy", policy) require.NoError(t, err) err = writer.WriteField(AmzCredential, creds) require.NoError(t, err) diff --git a/api/handler/put.go b/api/handler/put.go index 94cc0f3..e740cbb 100644 --- a/api/handler/put.go +++ b/api/handler/put.go @@ -9,6 +9,7 @@ import ( stderrors "errors" "fmt" "io" + "mime/multipart" "net" "net/http" "net/url" @@ -469,21 +470,47 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) { return } + reqInfo.ObjectName = auth.MultipartFormValue(r, "key") + var contentReader io.Reader var size uint64 + var filename string + if content, ok := r.MultipartForm.Value["file"]; ok { - contentReader = bytes.NewBufferString(content[0]) - size = uint64(len(content[0])) + fullContent := strings.Join(content, "") + contentReader = bytes.NewBufferString(fullContent) + size = uint64(len(fullContent)) + + if reqInfo.ObjectName == "" || strings.Contains(reqInfo.ObjectName, "${filename}") { + _, head, err := r.FormFile("file") + if err != nil { + h.logAndSendError(w, "could not parse file field", reqInfo, err) + return + } + filename = head.Filename + } } else { - file, head, err := r.FormFile("file") + var head *multipart.FileHeader + contentReader, head, err = r.FormFile("file") if err != nil { - h.logAndSendError(w, "could get uploading file", reqInfo, err) + h.logAndSendError(w, "could not parse file field", reqInfo, err) return } - contentReader = file size = uint64(head.Size) - reqInfo.ObjectName = strings.ReplaceAll(reqInfo.ObjectName, "${filename}", head.Filename) + filename = head.Filename } + + if reqInfo.ObjectName == "" { + reqInfo.ObjectName = filename + } else { + reqInfo.ObjectName = strings.ReplaceAll(reqInfo.ObjectName, "${filename}", filename) + } + + if reqInfo.ObjectName == "" { + h.logAndSendError(w, "missing object name", reqInfo, errors.GetAPIError(errors.ErrInvalidArgument)) + return + } + if !policy.CheckContentLength(size) { h.logAndSendError(w, "invalid content-length", reqInfo, errors.GetAPIError(errors.ErrInvalidArgument)) return @@ -599,10 +626,6 @@ func checkPostPolicy(r *http.Request, reqInfo *middleware.ReqInfo, metadata map[ if key == "content-type" { metadata[api.ContentType] = value } - - if key == "key" { - reqInfo.ObjectName = value - } } for _, cond := range policy.Conditions { diff --git a/api/handler/put_test.go b/api/handler/put_test.go index 152a969..f496660 100644 --- a/api/handler/put_test.go +++ b/api/handler/put_test.go @@ -17,6 +17,7 @@ import ( "time" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth" v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data" s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" @@ -122,6 +123,92 @@ func TestEmptyPostPolicy(t *testing.T) { require.NoError(t, err) } +// if content length is greater than this value +// data will be writen to file location. +const maxContentSizeForFormData = 10 + +func TestPostObject(t *testing.T) { + hc := prepareHandlerContext(t) + + ns, bktName := "", "bucket" + createTestBucket(hc, bktName) + + for _, tc := range []struct { + key string + filename string + content string + objName string + err bool + }{ + { + key: "user/user1/${filename}", + filename: "object", + content: "content", + objName: "user/user1/object", + }, + { + key: "user/user1/${filename}", + filename: "object", + content: "maxContentSizeForFormData", + objName: "user/user1/object", + }, + { + key: "user/user1/key-object", + filename: "object", + content: "", + objName: "user/user1/key-object", + }, + { + key: "user/user1/key-object", + filename: "object", + content: "maxContentSizeForFormData", + objName: "user/user1/key-object", + }, + { + key: "", + filename: "object", + content: "", + objName: "object", + }, + { + key: "", + filename: "object", + content: "maxContentSizeForFormData", + objName: "object", + }, + { + // RFC 7578, Section 4.2 requires that if a filename is provided, the + // directory path information must not be used. + key: "", + filename: "dir/object", + content: "content", + objName: "object", + }, + { + key: "object", + filename: "", + content: "content", + objName: "object", + }, + { + key: "", + filename: "", + err: true, + }, + } { + t.Run(tc.key+";"+tc.filename, func(t *testing.T) { + w := postObjectBase(hc, ns, bktName, tc.key, tc.filename, tc.content) + if tc.err { + assertS3Error(hc.t, w, s3errors.GetAPIError(s3errors.ErrInternalError)) + return + } + assertStatus(hc.t, w, http.StatusNoContent) + content, _ := getObject(hc, bktName, tc.objName) + require.Equal(t, tc.content, string(content)) + }) + } +} + func TestPutObjectOverrideCopiesNumber(t *testing.T) { tc := prepareHandlerContext(t) @@ -449,3 +536,85 @@ func TestPutObjectWithContentLanguage(t *testing.T) { tc.Handler().HeadObjectHandler(w, r) require.Equal(t, expectedContentLanguage, w.Header().Get(api.ContentLanguage)) } + +func postObjectBase(hc *handlerContext, ns, bktName, key, filename, content string) *httptest.ResponseRecorder { + policy := "eyJleHBpcmF0aW9uIjogIjIwMjUtMTItMDFUMTI6MDA6MDAuMDAwWiIsImNvbmRpdGlvbnMiOiBbCiBbInN0YXJ0cy13aXRoIiwgIiR4LWFtei1jcmVkZW50aWFsIiwgIiJdLAogWyJzdGFydHMtd2l0aCIsICIkeC1hbXotZGF0ZSIsICIiXSwKIFsic3RhcnRzLXdpdGgiLCAiJGtleSIsICIiXQpdfQ==" + + timeToSign := time.Now() + timeToSignStr := timeToSign.Format("20060102T150405Z") + region := "default" + service := "s3" + + accessKeyID := "5jizSbYu8hX345aqCKDgRWKCJYHxnzxRS8e6SUYHZ8Fw0HiRkf3KbJAWBn5mRzmiyHQ3UHADGyzVXLusn1BrmAfLn" + secretKey := "abf066d77c6744cd956a123a0b9612df587f5c14d3350ecb01b363f182dd7279" + + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := auth.SignStr(secretKey, service, region, timeToSign, policy) + + body, contentType, err := getMultipartFormBody(policy, creds, timeToSignStr, sign, key, filename, content) + require.NoError(hc.t, err) + + w, r := prepareTestPostRequest(hc, bktName, body) + r.Header.Set(auth.ContentTypeHdr, contentType) + r.Header.Set("X-Frostfs-Namespace", ns) + + err = r.ParseMultipartForm(50 * 1024 * 1024) + require.NoError(hc.t, err) + + hc.Handler().PostObject(w, r) + return w +} + +func getCredsStr(accessKeyID, timeToSign, region, service string) string { + return accessKeyID + "/" + timeToSign + "/" + region + "/" + service + "/aws4_request" +} + +func getMultipartFormBody(policy, creds, date, sign, key, filename, content string) (io.Reader, string, error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + defer writer.Close() + + if err := writer.WriteField("policy", policy); err != nil { + return nil, "", err + } + + if err := writer.WriteField("key", key); err != nil { + return nil, "", err + } + if err := writer.WriteField(strings.ToLower(auth.AmzCredential), creds); err != nil { + return nil, "", err + } + if err := writer.WriteField(strings.ToLower(auth.AmzDate), date); err != nil { + return nil, "", err + } + if err := writer.WriteField(strings.ToLower(auth.AmzSignature), sign); err != nil { + return nil, "", err + } + + file, err := writer.CreateFormFile("file", filename) + if err != nil { + return nil, "", err + } + + if len(content) < maxContentSizeForFormData { + if err = writer.WriteField("file", content); err != nil { + return nil, "", err + } + } else { + if _, err = file.Write([]byte(content)); err != nil { + return nil, "", err + } + } + + return body, writer.FormDataContentType(), nil +} + +func prepareTestPostRequest(hc *handlerContext, bktName string, payload io.Reader) (*httptest.ResponseRecorder, *http.Request) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, defaultURL+bktName, payload) + + reqInfo := middleware.NewReqInfo(w, r, middleware.ObjectRequest{Bucket: bktName}, "") + r = r.WithContext(middleware.SetReqInfo(hc.Context(), reqInfo)) + + return w, r +}