diff --git a/downloader/download.go b/downloader/download.go index dd810f1..1718db8 100644 --- a/downloader/download.go +++ b/downloader/download.go @@ -2,6 +2,7 @@ package downloader import ( "archive/zip" + "bytes" "context" "errors" "fmt" @@ -25,74 +26,13 @@ import ( "go.uber.org/zap" ) -type ( - detector struct { - io.Reader - err error - contentType string - done chan struct{} - data []byte - } - - request struct { - *fasthttp.RequestCtx - log *zap.Logger - } - - errReader struct { - data []byte - err error - offset int - } -) +type request struct { + *fasthttp.RequestCtx + log *zap.Logger +} var errObjectNotFound = errors.New("object not found") -func newReader(data []byte, err error) *errReader { - return &errReader{data: data, err: err} -} - -func (r *errReader) Read(b []byte) (int, error) { - if r.offset >= len(r.data) { - return 0, io.EOF - } - n := copy(b, r.data[r.offset:]) - r.offset += n - if r.offset >= len(r.data) { - return n, r.err - } - return n, nil -} - -const contentTypeDetectSize = 512 - -func newDetector() *detector { - return &detector{done: make(chan struct{}), data: make([]byte, contentTypeDetectSize)} -} - -func (d *detector) Wait() { - <-d.done -} - -func (d *detector) SetReader(reader io.Reader) { - d.Reader = reader -} - -func (d *detector) Detect() { - n, err := d.Reader.Read(d.data) - if err != nil && err != io.EOF { - d.err = err - return - } - d.data = d.data[:n] - d.contentType = http.DetectContentType(d.data) - close(d.done) -} - -func (d *detector) MultiReader() io.Reader { - return io.MultiReader(newReader(d.data, d.err), d.Reader) -} - func isValidToken(s string) bool { for _, c := range s { if c <= ' ' || c > 127 { @@ -115,6 +55,35 @@ func isValidValue(s string) bool { return true } +type readCloser struct { + io.Reader + io.Closer +} + +// initializes io.Reader with limited size and detects Content-Type from it. +// Returns r's error directly. Also returns processed data. +func readContentType(maxSize uint64, rInit func(uint64) (io.Reader, error)) (string, []byte, error) { + if maxSize > sizeToDetectType { + maxSize = sizeToDetectType + } + + buf := make([]byte, maxSize) // maybe sync-pool the slice? + + r, err := rInit(maxSize) + if err != nil { + return "", nil, err + } + + n, err := r.Read(buf) + if err != nil && err != io.EOF { + return "", nil, err + } + + buf = buf[:n] + + return http.DetectContentType(buf), buf, err // to not lose io.EOF +} + func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) { var ( err error @@ -140,12 +109,9 @@ func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) { dis = "attachment" } - readDetector := newDetector() - readDetector.SetReader(rObj.Payload) - readDetector.Detect() + payloadSize := rObj.Header.PayloadSize() - r.Response.SetBodyStream(readDetector.MultiReader(), int(rObj.Header.PayloadSize())) - r.Response.Header.Set(fasthttp.HeaderContentLength, strconv.FormatUint(rObj.Header.PayloadSize(), 10)) + r.Response.Header.Set(fasthttp.HeaderContentLength, strconv.FormatUint(payloadSize, 10)) var contentType string for _, attr := range rObj.Header.Attributes() { key := attr.Key() @@ -179,17 +145,34 @@ func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) { idsToResponse(&r.Response, &rObj.Header) if len(contentType) == 0 { - if readDetector.err != nil { - r.log.Error("could not read object", zap.Error(err)) - response.Error(r.RequestCtx, "could not read object", fasthttp.StatusBadRequest) + // determine the Content-Type from the payload head + var payloadHead []byte + + contentType, payloadHead, err = readContentType(payloadSize, func(uint64) (io.Reader, error) { + return rObj.Payload, nil + }) + if err != nil && err != io.EOF { + r.log.Error("could not detect Content-Type from payload", zap.Error(err)) + response.Error(r.RequestCtx, "could not detect Content-Type from payload", fasthttp.StatusBadRequest) return } - readDetector.Wait() - contentType = readDetector.contentType + + // reset payload reader since part of the data has been read + var r io.Reader = bytes.NewReader(payloadHead) + + if err != io.EOF { // otherwise, we've already read full payload + r = io.MultiReader(r, rObj.Payload) + } + + // note: we could do with io.Reader, but SetBodyStream below closes body stream + // if it implements io.Closer and that's useful for us. + rObj.Payload = readCloser{r, rObj.Payload} } r.SetContentType(contentType) r.Response.Header.Set(fasthttp.HeaderContentDisposition, dis+"; filename="+path.Base(filename)) + + r.Response.SetBodyStream(rObj.Payload, int(payloadSize)) } // systemBackwardTranslator is used to convert headers looking like '__NEOFS__ATTR_NAME' to 'Neofs-Attr-Name'. diff --git a/downloader/head.go b/downloader/head.go index 7d9fdae..620d15d 100644 --- a/downloader/head.go +++ b/downloader/head.go @@ -16,6 +16,7 @@ import ( "go.uber.org/zap" ) +// max bytes needed to detect content type according to http.DetectContentType docs. const sizeToDetectType = 512 const ( @@ -67,27 +68,13 @@ func (r request) headObject(clnt pool.Object, objectAddress *address.Address) { idsToResponse(&r.Response, obj) if len(contentType) == 0 { - sz := obj.PayloadSize() - if sz > sizeToDetectType { - sz = sizeToDetectType - } - - res, err := clnt.ObjectRange(r.RequestCtx, *objectAddress, 0, sz, bearerOpt) - if err != nil { + contentType, _, err = readContentType(obj.PayloadSize(), func(sz uint64) (io.Reader, error) { + return clnt.ObjectRange(r.RequestCtx, *objectAddress, 0, sz, bearerOpt) + }) + if err != nil && err != io.EOF { r.handleNeoFSErr(err, start) return } - - defer res.Close() - - data := make([]byte, sz) // sync-pool it? - - _, err = io.ReadFull(res, data) - if err != nil { - r.handleNeoFSErr(err, start) - return - } - contentType = http.DetectContentType(data) } r.SetContentType(contentType) } diff --git a/downloader/reader_test.go b/downloader/reader_test.go index 96a42e1..8d58185 100644 --- a/downloader/reader_test.go +++ b/downloader/reader_test.go @@ -1,8 +1,6 @@ package downloader import ( - "bytes" - "fmt" "io" "strings" "testing" @@ -10,40 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestReader(t *testing.T) { - data := []byte("test string") - err := fmt.Errorf("something wrong") - - for _, tc := range []struct { - err error - buff []byte - }{ - {err: nil, buff: make([]byte, len(data)+1)}, - {err: nil, buff: make([]byte, len(data))}, - {err: nil, buff: make([]byte, len(data)-1)}, - {err: err, buff: make([]byte, len(data)+1)}, - {err: err, buff: make([]byte, len(data))}, - {err: err, buff: make([]byte, len(data)-1)}, - } { - var res []byte - var err error - var n int - - r := newReader(data, tc.err) - for err == nil { - n, err = r.Read(tc.buff) - res = append(res, tc.buff[:n]...) - } - - if tc.err == nil { - require.Equal(t, io.EOF, err) - } else { - require.Equal(t, tc.err, err) - } - require.Equal(t, data, res) - } -} - func TestDetector(t *testing.T) { txtContentType := "text/plain; charset=utf-8" sb := strings.Builder{} @@ -68,19 +32,15 @@ func TestDetector(t *testing.T) { }, } { t.Run(tc.Name, func(t *testing.T) { - detector := newDetector() + contentType, data, err := readContentType(uint64(len(tc.Expected)), + func(sz uint64) (io.Reader, error) { + return strings.NewReader(tc.Expected), nil + }, + ) - go func() { - detector.SetReader(bytes.NewBufferString(tc.Expected)) - detector.Detect() - }() - - detector.Wait() - require.Equal(t, tc.ContentType, detector.contentType) - - data, err := io.ReadAll(detector.MultiReader()) require.NoError(t, err) - require.Equal(t, tc.Expected, string(data)) + require.Equal(t, tc.ContentType, contentType) + require.True(t, strings.HasPrefix(tc.Expected, string(data))) }) } }