diff --git a/api/handler/multipart_upload.go b/api/handler/multipart_upload.go index 18d717c..edffd06 100644 --- a/api/handler/multipart_upload.go +++ b/api/handler/multipart_upload.go @@ -199,17 +199,15 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) { return } - body, decodedSize, err := h.getBodyReader(r) + body, err := h.getBodyReader(r) if err != nil { h.logAndSendError(w, "failed to get body reader", reqInfo, err, additional...) return } - var size uint64 - if decodedSize > 0 { - size = uint64(decodedSize) - } else if r.ContentLength > 0 { - size = uint64(r.ContentLength) + size, err := h.getPutPayloadSize(r) + if err != nil { + h.logAndSendError(w, "failed to get payload size", reqInfo, err, additional...) } p := &layer.UploadPartParams{ diff --git a/api/handler/put.go b/api/handler/put.go index c4f0453..c0b35de 100644 --- a/api/handler/put.go +++ b/api/handler/put.go @@ -233,7 +233,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) { return } - body, decodedSize, err := h.getBodyReader(r) + body, err := h.getBodyReader(r) if err != nil { h.logAndSendError(w, "failed to get body reader", reqInfo, err) return @@ -242,12 +242,9 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) { metadata[api.ContentEncoding] = encodings } - var size uint64 - - if decodedSize > 0 { - size = uint64(decodedSize) - } else if r.ContentLength > 0 { - size = uint64(r.ContentLength) + size, err := h.getPutPayloadSize(r) + if err != nil { + return } params := &layer.PutObjectParams{ @@ -313,9 +310,9 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) { } } -func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, int, error) { +func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) { if !api.IsSignedStreamingV4(r) { - return r.Body, -1, nil + return r.Body, nil } encodings := r.Header.Values(api.ContentEncoding) @@ -334,26 +331,25 @@ func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, int, error) { r.Header.Set(api.ContentEncoding, strings.Join(resultContentEncoding, ",")) if !chunkedEncoding && !h.cfg.BypassContentEncodingInChunks() { - return nil, -1, fmt.Errorf("%w: request is not chunk encoded, encodings '%s'", + return nil, fmt.Errorf("%w: request is not chunk encoded, encodings '%s'", errors.GetAPIError(errors.ErrInvalidEncodingMethod), strings.Join(encodings, ",")) } decodeContentSize := r.Header.Get(api.AmzDecodedContentLength) if len(decodeContentSize) == 0 { - return nil, -1, errors.GetAPIError(errors.ErrMissingContentLength) + return nil, errors.GetAPIError(errors.ErrMissingContentLength) } - decoded, err := strconv.Atoi(decodeContentSize) - if err != nil { - return nil, -1, fmt.Errorf("%w: parse decoded content length: %s", errors.GetAPIError(errors.ErrMissingContentLength), err.Error()) + if _, err := strconv.Atoi(decodeContentSize); err != nil { + return nil, fmt.Errorf("%w: parse decoded content length: %s", errors.GetAPIError(errors.ErrMissingContentLength), err.Error()) } chunkReader, err := newSignV4ChunkedReader(r) if err != nil { - return nil, -1, fmt.Errorf("initialize chunk reader: %w", err) + return nil, fmt.Errorf("initialize chunk reader: %w", err) } - return chunkReader, decoded, nil + return chunkReader, nil } func formEncryptionParams(r *http.Request) (enc encryption.Params, err error) { diff --git a/api/handler/util.go b/api/handler/util.go index dde5274..464f878 100644 --- a/api/handler/util.go +++ b/api/handler/util.go @@ -106,6 +106,27 @@ func (h *handler) getBucketAndCheckOwner(r *http.Request, bucket string, header return bktInfo, checkOwner(bktInfo, expected) } +func (h *handler) getPutPayloadSize(r *http.Request) (uint64, error) { + decodeContentSize := r.Header.Get(api.AmzDecodedContentLength) + if len(decodeContentSize) == 0 { + return 0, s3errors.GetAPIError(s3errors.ErrMissingContentLength) + } + + decodedSize, err := strconv.Atoi(decodeContentSize) + if err != nil { + return 0, fmt.Errorf("%w: parse decoded content length: %s", s3errors.GetAPIError(s3errors.ErrMissingContentLength), err.Error()) + } + + var size uint64 + if decodedSize > 0 { + size = uint64(decodedSize) + } else if r.ContentLength > 0 { + size = uint64(r.ContentLength) + } + + return size, nil +} + func parseRange(s string) (*layer.RangeParams, error) { if s == "" { return nil, nil