diff --git a/api/handler/lifecycle.go b/api/handler/lifecycle.go index e458101..bd89fe9 100644 --- a/api/handler/lifecycle.go +++ b/api/handler/lifecycle.go @@ -1,8 +1,11 @@ package handler import ( + "bytes" + "crypto/md5" "encoding/base64" "fmt" + "io" "math" "net/http" "time" @@ -44,6 +47,9 @@ func (h *handler) GetBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque } func (h *handler) PutBucketLifecycleHandler(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + + tee := io.TeeReader(r.Body, &buf) ctx := r.Context() reqInfo := middleware.GetReqInfo(ctx) @@ -54,23 +60,35 @@ func (h *handler) PutBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque return } - if _, err := base64.StdEncoding.DecodeString(r.Header.Get(api.ContentMD5)); err != nil { + headerMD5, err := base64.StdEncoding.DecodeString(r.Header.Get(api.ContentMD5)) + if err != nil { h.logAndSendError(w, "invalid Content-MD5", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest)) return } + cfg := new(data.LifecycleConfiguration) + if err = h.cfg.NewXMLDecoder(tee).Decode(cfg); err != nil { + h.logAndSendError(w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error())) + return + } + + bodyMD5, err := getContentMD5(&buf) + if err != nil { + h.logAndSendError(w, "could not get content md5", reqInfo, err) + return + } + + if !bytes.Equal(headerMD5, bodyMD5) { + h.logAndSendError(w, "Content-MD5 does not match", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest)) + return + } + bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName) if err != nil { h.logAndSendError(w, "could not get bucket info", reqInfo, err) return } - cfg := new(data.LifecycleConfiguration) - if err = h.cfg.NewXMLDecoder(r.Body).Decode(cfg); err != nil { - h.logAndSendError(w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error())) - return - } - networkInfo, err := h.obj.GetNetworkInfo(ctx) if err != nil { h.logAndSendError(w, "could not get network info", reqInfo, err) @@ -286,3 +304,24 @@ func timeToEpoch(ni *netmap.NetworkInfo, t time.Time) (uint64, error) { return epoch, nil } + +func getContentMD5(reader io.Reader) ([]byte, error) { + hash := md5.New() + buf := make([]byte, 64*1024) + + for { + n, err := reader.Read(buf) + if n > 0 { + println(string(buf[:n])) + hash.Write(buf[:n]) + } + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + } + + return hash.Sum(nil), nil +} diff --git a/api/handler/lifecycle_test.go b/api/handler/lifecycle_test.go index ecb9f26..5c5b96d 100644 --- a/api/handler/lifecycle_test.go +++ b/api/handler/lifecycle_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "crypto/md5" "crypto/rand" "encoding/base64" @@ -377,6 +378,11 @@ func TestPutBucketLifecycleInvalidMD5(t *testing.T) { hc.Handler().PutBucketLifecycleHandler(w, r) assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMissingContentMD5)) + w, r = prepareTestRequest(hc, bktName, "", lifecycle) + r.Header.Set(api.ContentMD5, "") + hc.Handler().PutBucketLifecycleHandler(w, r) + assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrInvalidDigest)) + w, r = prepareTestRequest(hc, bktName, "", lifecycle) r.Header.Set(api.ContentMD5, "some-hash") hc.Handler().PutBucketLifecycleHandler(w, r) @@ -389,8 +395,14 @@ func TestPutBucketLifecycleInvalidXML(t *testing.T) { bktName := "bucket-lifecycle-invalid-xml" createBucket(hc, bktName) - w, r := prepareTestRequest(hc, bktName, "", &data.CORSConfiguration{}) - r.Header.Set(api.ContentMD5, "") + cfg := &data.CORSConfiguration{} + body, err := xml.Marshal(cfg) + require.NoError(t, err) + contentMD5, err := getContentMD5(bytes.NewReader(body)) + require.NoError(t, err) + + w, r := prepareTestRequest(hc, bktName, "", cfg) + r.Header.Set(api.ContentMD5, base64.StdEncoding.EncodeToString(contentMD5)) hc.Handler().PutBucketLifecycleHandler(w, r) assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMalformedXML)) }