From 562c7e994c62db17d4c4319e12167d579102f559 Mon Sep 17 00:00:00 2001 From: Denis Kirillov Date: Fri, 29 Apr 2022 18:25:26 +0300 Subject: [PATCH] [#148] Add custom multipart reader Signed-off-by: Denis Kirillov --- uploader/multipart.go | 4 +- uploader/multipart/multipart.go | 422 ++++++++++++++++++++++++++++++++ uploader/multipart_test.go | 161 ++++++++++++ 3 files changed, 586 insertions(+), 1 deletion(-) create mode 100644 uploader/multipart/multipart.go create mode 100644 uploader/multipart_test.go diff --git a/uploader/multipart.go b/uploader/multipart.go index c79ab94..5928e60 100644 --- a/uploader/multipart.go +++ b/uploader/multipart.go @@ -2,8 +2,8 @@ package uploader import ( "io" - "mime/multipart" + "github.com/nspcc-dev/neofs-http-gw/uploader/multipart" "go.uber.org/zap" ) @@ -15,6 +15,8 @@ type MultipartFile interface { } func fetchMultipartFile(l *zap.Logger, r io.Reader, boundary string) (MultipartFile, error) { + // To have a custom buffer (3mb) the custom multipart reader is used. + // https://github.com/nspcc-dev/neofs-http-gw/issues/148 reader := multipart.NewReader(r, boundary) for { diff --git a/uploader/multipart/multipart.go b/uploader/multipart/multipart.go new file mode 100644 index 0000000..6fc6031 --- /dev/null +++ b/uploader/multipart/multipart.go @@ -0,0 +1,422 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// + +/* +Package multipart implements MIME multipart parsing, as defined in RFC +2046. + +The implementation is sufficient for HTTP (RFC 2388) and the multipart +bodies generated by popular browsers. +*/ +package multipart + +import ( + "bufio" + "bytes" + "fmt" + "io" + "mime" + "mime/quotedprintable" + "net/textproto" + "strings" +) + +var emptyParams = make(map[string]string) + +// This constant needs to be at least 76 for this package to work correctly. +// This is because \r\n--separator_of_len_70- would fill the buffer and it +// wouldn't be safe to consume a single byte from it. +// This constant is different from the constant in stdlib. The standard value is 4096. +const peekBufferSize = 3 << 20 + +// A Part represents a single part in a multipart body. +type Part struct { + // The headers of the body, if any, with the keys canonicalized + // in the same fashion that the Go http.Request headers are. + // For example, "foo-bar" changes case to "Foo-Bar" + Header textproto.MIMEHeader + + mr *Reader + + disposition string + dispositionParams map[string]string + + // r is either a reader directly reading from mr, or it's a + // wrapper around such a reader, decoding the + // Content-Transfer-Encoding + r io.Reader + + n int // known data bytes waiting in mr.bufReader + total int64 // total data bytes read already + err error // error to return when n == 0 + readErr error // read error observed from mr.bufReader +} + +// FormName returns the name parameter if p has a Content-Disposition +// of type "form-data". Otherwise it returns the empty string. +func (p *Part) FormName() string { + // See https://tools.ietf.org/html/rfc2183 section 2 for EBNF + // of Content-Disposition value format. + if p.dispositionParams == nil { + p.parseContentDisposition() + } + if p.disposition != "form-data" { + return "" + } + return p.dispositionParams["name"] +} + +// FileName returns the filename parameter of the Part's +// Content-Disposition header. +func (p *Part) FileName() string { + if p.dispositionParams == nil { + p.parseContentDisposition() + } + return p.dispositionParams["filename"] +} + +func (p *Part) parseContentDisposition() { + v := p.Header.Get("Content-Disposition") + var err error + p.disposition, p.dispositionParams, err = mime.ParseMediaType(v) + if err != nil { + p.dispositionParams = emptyParams + } +} + +// NewReader creates a new multipart Reader reading from r using the +// given MIME boundary. +// +// The boundary is usually obtained from the "boundary" parameter of +// the message's "Content-Type" header. Use mime.ParseMediaType to +// parse such headers. +func NewReader(r io.Reader, boundary string) *Reader { + b := []byte("\r\n--" + boundary + "--") + return &Reader{ + bufReader: bufio.NewReaderSize(&stickyErrorReader{r: r}, peekBufferSize), + nl: b[:2], + nlDashBoundary: b[:len(b)-2], + dashBoundaryDash: b[2:], + dashBoundary: b[2 : len(b)-2], + } +} + +// stickyErrorReader is an io.Reader which never calls Read on its +// underlying Reader once an error has been seen. (the io.Reader +// interface's contract promises nothing about the return values of +// Read calls after an error, yet this package does do multiple Reads +// after error) +type stickyErrorReader struct { + r io.Reader + err error +} + +func (r *stickyErrorReader) Read(p []byte) (n int, _ error) { + if r.err != nil { + return 0, r.err + } + n, r.err = r.r.Read(p) + return n, r.err +} + +func newPart(mr *Reader, rawPart bool) (*Part, error) { + bp := &Part{ + Header: make(map[string][]string), + mr: mr, + } + if err := bp.populateHeaders(); err != nil { + return nil, err + } + bp.r = partReader{bp} + + // rawPart is used to switch between Part.NextPart and Part.NextRawPart. + if !rawPart { + const cte = "Content-Transfer-Encoding" + if strings.EqualFold(bp.Header.Get(cte), "quoted-printable") { + bp.Header.Del(cte) + bp.r = quotedprintable.NewReader(bp.r) + } + } + return bp, nil +} + +func (bp *Part) populateHeaders() error { + r := textproto.NewReader(bp.mr.bufReader) + header, err := r.ReadMIMEHeader() + if err == nil { + bp.Header = header + } + return err +} + +// Read reads the body of a part, after its headers and before the +// next part (if any) begins. +func (p *Part) Read(d []byte) (n int, err error) { + return p.r.Read(d) +} + +// partReader implements io.Reader by reading raw bytes directly from the +// wrapped *Part, without doing any Transfer-Encoding decoding. +type partReader struct { + p *Part +} + +func (pr partReader) Read(d []byte) (int, error) { + p := pr.p + br := p.mr.bufReader + + // Read into buffer until we identify some data to return, + // or we find a reason to stop (boundary or read error). + for p.n == 0 && p.err == nil { + peek, _ := br.Peek(br.Buffered()) + p.n, p.err = scanUntilBoundary(peek, p.mr.dashBoundary, p.mr.nlDashBoundary, p.total, p.readErr) + if p.n == 0 && p.err == nil { + // Force buffered I/O to read more into buffer. + _, p.readErr = br.Peek(len(peek) + 1) + if p.readErr == io.EOF { + p.readErr = io.ErrUnexpectedEOF + } + } + } + + // Read out from "data to return" part of buffer. + if p.n == 0 { + return 0, p.err + } + n := len(d) + if n > p.n { + n = p.n + } + n, _ = br.Read(d[:n]) + p.total += int64(n) + p.n -= n + if p.n == 0 { + return n, p.err + } + return n, nil +} + +// scanUntilBoundary scans buf to identify how much of it can be safely +// returned as part of the Part body. +// dashBoundary is "--boundary". +// nlDashBoundary is "\r\n--boundary" or "\n--boundary", depending on what mode we are in. +// The comments below (and the name) assume "\n--boundary", but either is accepted. +// total is the number of bytes read out so far. If total == 0, then a leading "--boundary" is recognized. +// readErr is the read error, if any, that followed reading the bytes in buf. +// scanUntilBoundary returns the number of data bytes from buf that can be +// returned as part of the Part body and also the error to return (if any) +// once those data bytes are done. +func scanUntilBoundary(buf, dashBoundary, nlDashBoundary []byte, total int64, readErr error) (int, error) { + if total == 0 { + // At beginning of body, allow dashBoundary. + if bytes.HasPrefix(buf, dashBoundary) { + switch matchAfterPrefix(buf, dashBoundary, readErr) { + case -1: + return len(dashBoundary), nil + case 0: + return 0, nil + case +1: + return 0, io.EOF + } + } + if bytes.HasPrefix(dashBoundary, buf) { + return 0, readErr + } + } + + // Search for "\n--boundary". + if i := bytes.Index(buf, nlDashBoundary); i >= 0 { + switch matchAfterPrefix(buf[i:], nlDashBoundary, readErr) { + case -1: + return i + len(nlDashBoundary), nil + case 0: + return i, nil + case +1: + return i, io.EOF + } + } + if bytes.HasPrefix(nlDashBoundary, buf) { + return 0, readErr + } + + // Otherwise, anything up to the final \n is not part of the boundary + // and so must be part of the body. + // Also if the section from the final \n onward is not a prefix of the boundary, + // it too must be part of the body. + i := bytes.LastIndexByte(buf, nlDashBoundary[0]) + if i >= 0 && bytes.HasPrefix(nlDashBoundary, buf[i:]) { + return i, nil + } + return len(buf), readErr +} + +// matchAfterPrefix checks whether buf should be considered to match the boundary. +// The prefix is "--boundary" or "\r\n--boundary" or "\n--boundary", +// and the caller has verified already that bytes.HasPrefix(buf, prefix) is true. +// +// matchAfterPrefix returns +1 if the buffer does match the boundary, +// meaning the prefix is followed by a dash, space, tab, cr, nl, or end of input. +// It returns -1 if the buffer definitely does NOT match the boundary, +// meaning the prefix is followed by some other character. +// For example, "--foobar" does not match "--foo". +// It returns 0 more input needs to be read to make the decision, +// meaning that len(buf) == len(prefix) and readErr == nil. +func matchAfterPrefix(buf, prefix []byte, readErr error) int { + if len(buf) == len(prefix) { + if readErr != nil { + return +1 + } + return 0 + } + c := buf[len(prefix)] + if c == ' ' || c == '\t' || c == '\r' || c == '\n' || c == '-' { + return +1 + } + return -1 +} + +func (p *Part) Close() error { + io.Copy(io.Discard, p) + return nil +} + +// Reader is an iterator over parts in a MIME multipart body. +// Reader's underlying parser consumes its input as needed. Seeking +// isn't supported. +type Reader struct { + bufReader *bufio.Reader + + currentPart *Part + partsRead int + + nl []byte // "\r\n" or "\n" (set after seeing first boundary line) + nlDashBoundary []byte // nl + "--boundary" + dashBoundaryDash []byte // "--boundary--" + dashBoundary []byte // "--boundary" +} + +// NextPart returns the next part in the multipart or an error. +// When there are no more parts, the error io.EOF is returned. +// +// As a special case, if the "Content-Transfer-Encoding" header +// has a value of "quoted-printable", that header is instead +// hidden and the body is transparently decoded during Read calls. +func (r *Reader) NextPart() (*Part, error) { + return r.nextPart(false) +} + +// NextRawPart returns the next part in the multipart or an error. +// When there are no more parts, the error io.EOF is returned. +// +// Unlike NextPart, it does not have special handling for +// "Content-Transfer-Encoding: quoted-printable". +func (r *Reader) NextRawPart() (*Part, error) { + return r.nextPart(true) +} + +func (r *Reader) nextPart(rawPart bool) (*Part, error) { + if r.currentPart != nil { + r.currentPart.Close() + } + if string(r.dashBoundary) == "--" { + return nil, fmt.Errorf("multipart: boundary is empty") + } + expectNewPart := false + for { + line, err := r.bufReader.ReadSlice('\n') + + if err == io.EOF && r.isFinalBoundary(line) { + // If the buffer ends in "--boundary--" without the + // trailing "\r\n", ReadSlice will return an error + // (since it's missing the '\n'), but this is a valid + // multipart EOF so we need to return io.EOF instead of + // a fmt-wrapped one. + return nil, io.EOF + } + if err != nil { + return nil, fmt.Errorf("multipart: NextPart: %v", err) + } + + if r.isBoundaryDelimiterLine(line) { + r.partsRead++ + bp, err := newPart(r, rawPart) + if err != nil { + return nil, err + } + r.currentPart = bp + return bp, nil + } + + if r.isFinalBoundary(line) { + // Expected EOF + return nil, io.EOF + } + + if expectNewPart { + return nil, fmt.Errorf("multipart: expecting a new Part; got line %q", string(line)) + } + + if r.partsRead == 0 { + // skip line + continue + } + + // Consume the "\n" or "\r\n" separator between the + // body of the previous part and the boundary line we + // now expect will follow. (either a new part or the + // end boundary) + if bytes.Equal(line, r.nl) { + expectNewPart = true + continue + } + + return nil, fmt.Errorf("multipart: unexpected line in Next(): %q", line) + } +} + +// isFinalBoundary reports whether line is the final boundary line +// indicating that all parts are over. +// It matches `^--boundary--[ \t]*(\r\n)?$` +func (mr *Reader) isFinalBoundary(line []byte) bool { + if !bytes.HasPrefix(line, mr.dashBoundaryDash) { + return false + } + rest := line[len(mr.dashBoundaryDash):] + rest = skipLWSPChar(rest) + return len(rest) == 0 || bytes.Equal(rest, mr.nl) +} + +func (mr *Reader) isBoundaryDelimiterLine(line []byte) (ret bool) { + // https://tools.ietf.org/html/rfc2046#section-5.1 + // The boundary delimiter line is then defined as a line + // consisting entirely of two hyphen characters ("-", + // decimal value 45) followed by the boundary parameter + // value from the Content-Type header field, optional linear + // whitespace, and a terminating CRLF. + if !bytes.HasPrefix(line, mr.dashBoundary) { + return false + } + rest := line[len(mr.dashBoundary):] + rest = skipLWSPChar(rest) + + // On the first part, see our lines are ending in \n instead of \r\n + // and switch into that mode if so. This is a violation of the spec, + // but occurs in practice. + if mr.partsRead == 0 && len(rest) == 1 && rest[0] == '\n' { + mr.nl = mr.nl[1:] + mr.nlDashBoundary = mr.nlDashBoundary[1:] + } + return bytes.Equal(rest, mr.nl) +} + +// skipLWSPChar returns b with leading spaces and tabs removed. +// RFC 822 defines: +// LWSP-char = SPACE / HTAB +func skipLWSPChar(b []byte) []byte { + for len(b) > 0 && (b[0] == ' ' || b[0] == '\t') { + b = b[1:] + } + return b +} diff --git a/uploader/multipart_test.go b/uploader/multipart_test.go new file mode 100644 index 0000000..9763f88 --- /dev/null +++ b/uploader/multipart_test.go @@ -0,0 +1,161 @@ +package uploader + +import ( + "crypto/rand" + "fmt" + "io" + "mime/multipart" + "os" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func generateRandomFile(size int64) (string, error) { + file, err := os.CreateTemp("", "data") + if err != nil { + return "", err + } + + _, err = io.CopyN(file, rand.Reader, size) + if err != nil { + return "", err + } + + return file.Name(), file.Close() +} + +func BenchmarkAll(b *testing.B) { + fileName, err := generateRandomFile(1024 * 1024 * 256) + require.NoError(b, err) + fmt.Println(fileName) + defer os.Remove(fileName) + + b.Run("bare", func(b *testing.B) { + for i := 0; i < b.N; i++ { + err := bareRead(fileName) + require.NoError(b, err) + } + }) + + b.Run("default", func(b *testing.B) { + for i := 0; i < b.N; i++ { + err := defaultMultipart(fileName) + require.NoError(b, err) + } + }) + + b.Run("custom", func(b *testing.B) { + for i := 0; i < b.N; i++ { + err := customMultipart(fileName) + require.NoError(b, err) + } + }) +} + +func defaultMultipart(filename string) error { + r, bound := multipartFile(filename) + + logger, err := zap.NewProduction() + if err != nil { + return err + } + + file, err := fetchMultipartFileDefault(logger, r, bound) + if err != nil { + return err + } + + _, err = io.Copy(io.Discard, file) + return err +} + +func TestName(t *testing.T) { + fileName, err := generateRandomFile(1024 * 1024 * 256) + require.NoError(t, err) + fmt.Println(fileName) + defer os.Remove(fileName) + + err = defaultMultipart(fileName) + require.NoError(t, err) +} + +func customMultipart(filename string) error { + r, bound := multipartFile(filename) + + logger, err := zap.NewProduction() + if err != nil { + return err + } + + file, err := fetchMultipartFile(logger, r, bound) + if err != nil { + return err + } + + _, err = io.Copy(io.Discard, file) + return err +} + +func fetchMultipartFileDefault(l *zap.Logger, r io.Reader, boundary string) (MultipartFile, error) { + reader := multipart.NewReader(r, boundary) + + for { + part, err := reader.NextPart() + if err != nil { + return nil, err + } + + name := part.FormName() + if name == "" { + l.Debug("ignore part, empty form name") + continue + } + + filename := part.FileName() + + // ignore multipart/form-data values + if filename == "" { + l.Debug("ignore part, empty filename", zap.String("form", name)) + + continue + } + + return part, nil + } +} + +func bareRead(filename string) error { + r, _ := multipartFile(filename) + + _, err := io.Copy(io.Discard, r) + return err +} + +func multipartFile(filename string) (*io.PipeReader, string) { + r, w := io.Pipe() + m := multipart.NewWriter(w) + go func() { + defer w.Close() + defer m.Close() + part, err := m.CreateFormFile("myFile", "foo.txt") + if err != nil { + fmt.Println(err) + return + } + + file, err := os.Open(filename) + if err != nil { + fmt.Println(err) + return + } + defer file.Close() + if _, err = io.Copy(part, file); err != nil { + fmt.Println(err) + return + } + }() + + return r, m.Boundary() +}