diff --git a/api/handler/lifecycle.go b/api/handler/lifecycle.go index 88e96cf..9bd1df8 100644 --- a/api/handler/lifecycle.go +++ b/api/handler/lifecycle.go @@ -1,9 +1,12 @@ package handler import ( + "bytes" "context" + "crypto/md5" "encoding/base64" "fmt" + "io" "net/http" "time" @@ -45,6 +48,9 @@ func (h *handler) GetBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque } func (h *handler) PutBucketLifecycleHandler(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + + tee := io.TeeReader(r.Body, &buf) ctx := r.Context() reqInfo := middleware.GetReqInfo(ctx) @@ -55,23 +61,35 @@ func (h *handler) PutBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque return } - if _, err := base64.StdEncoding.DecodeString(r.Header.Get(api.ContentMD5)); err != nil { + headerMD5, err := base64.StdEncoding.DecodeString(r.Header.Get(api.ContentMD5)) + if err != nil { h.logAndSendError(w, "invalid Content-MD5", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest)) return } + cfg := new(data.LifecycleConfiguration) + if err = h.cfg.NewXMLDecoder(tee).Decode(cfg); err != nil { + h.logAndSendError(w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error())) + return + } + + bodyMD5, err := getContentMD5(&buf) + if err != nil { + h.logAndSendError(w, "could not get content md5", reqInfo, err) + return + } + + if !bytes.Equal(headerMD5, bodyMD5) { + h.logAndSendError(w, "Content-MD5 does not match", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest)) + return + } + bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName) if err != nil { h.logAndSendError(w, "could not get bucket info", reqInfo, err) return } - cfg := new(data.LifecycleConfiguration) - if err = h.cfg.NewXMLDecoder(r.Body).Decode(cfg); err != nil { - h.logAndSendError(w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error())) - return - } - networkInfo, err := h.obj.GetNetworkInfo(ctx) if err != nil { h.logAndSendError(w, "could not get network info", reqInfo, err) @@ -253,3 +271,12 @@ func checkLifecycleRuleFilter(filter *data.LifecycleRuleFilter) error { return nil } + +func getContentMD5(reader io.Reader) ([]byte, error) { + hash := md5.New() + _, err := io.Copy(hash, reader) + if err != nil { + return nil, err + } + return hash.Sum(nil), nil +} diff --git a/api/handler/lifecycle_test.go b/api/handler/lifecycle_test.go index 8cba0fc..0794d55 100644 --- a/api/handler/lifecycle_test.go +++ b/api/handler/lifecycle_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "crypto/md5" "crypto/rand" "encoding/base64" @@ -376,6 +377,11 @@ func TestPutBucketLifecycleInvalidMD5(t *testing.T) { hc.Handler().PutBucketLifecycleHandler(w, r) assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMissingContentMD5)) + w, r = prepareTestRequest(hc, bktName, "", lifecycle) + r.Header.Set(api.ContentMD5, "") + hc.Handler().PutBucketLifecycleHandler(w, r) + assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrInvalidDigest)) + w, r = prepareTestRequest(hc, bktName, "", lifecycle) r.Header.Set(api.ContentMD5, "some-hash") hc.Handler().PutBucketLifecycleHandler(w, r) @@ -388,8 +394,14 @@ func TestPutBucketLifecycleInvalidXML(t *testing.T) { bktName := "bucket-lifecycle-invalid-xml" createBucket(hc, bktName) - w, r := prepareTestRequest(hc, bktName, "", &data.CORSConfiguration{}) - r.Header.Set(api.ContentMD5, "") + cfg := &data.CORSConfiguration{} + body, err := xml.Marshal(cfg) + require.NoError(t, err) + contentMD5, err := getContentMD5(bytes.NewReader(body)) + require.NoError(t, err) + + w, r := prepareTestRequest(hc, bktName, "", cfg) + r.Header.Set(api.ContentMD5, base64.StdEncoding.EncodeToString(contentMD5)) hc.Handler().PutBucketLifecycleHandler(w, r) assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMalformedXML)) }