[#128] downloader: Simplify detecting the Content-Type from payload
Signed-off-by: Leonard Lyubich <leonard@nspcc.ru>
This commit is contained in:
parent
f55edbb613
commit
271451dc32
3 changed files with 70 additions and 140 deletions
|
@ -2,6 +2,7 @@ package downloader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -25,74 +26,13 @@ import (
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type request struct {
|
||||||
detector struct {
|
*fasthttp.RequestCtx
|
||||||
io.Reader
|
log *zap.Logger
|
||||||
err error
|
}
|
||||||
contentType string
|
|
||||||
done chan struct{}
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
request struct {
|
|
||||||
*fasthttp.RequestCtx
|
|
||||||
log *zap.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
errReader struct {
|
|
||||||
data []byte
|
|
||||||
err error
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
var errObjectNotFound = errors.New("object not found")
|
var errObjectNotFound = errors.New("object not found")
|
||||||
|
|
||||||
func newReader(data []byte, err error) *errReader {
|
|
||||||
return &errReader{data: data, err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *errReader) Read(b []byte) (int, error) {
|
|
||||||
if r.offset >= len(r.data) {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
n := copy(b, r.data[r.offset:])
|
|
||||||
r.offset += n
|
|
||||||
if r.offset >= len(r.data) {
|
|
||||||
return n, r.err
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const contentTypeDetectSize = 512
|
|
||||||
|
|
||||||
func newDetector() *detector {
|
|
||||||
return &detector{done: make(chan struct{}), data: make([]byte, contentTypeDetectSize)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *detector) Wait() {
|
|
||||||
<-d.done
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *detector) SetReader(reader io.Reader) {
|
|
||||||
d.Reader = reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *detector) Detect() {
|
|
||||||
n, err := d.Reader.Read(d.data)
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
d.err = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.data = d.data[:n]
|
|
||||||
d.contentType = http.DetectContentType(d.data)
|
|
||||||
close(d.done)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *detector) MultiReader() io.Reader {
|
|
||||||
return io.MultiReader(newReader(d.data, d.err), d.Reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidToken(s string) bool {
|
func isValidToken(s string) bool {
|
||||||
for _, c := range s {
|
for _, c := range s {
|
||||||
if c <= ' ' || c > 127 {
|
if c <= ' ' || c > 127 {
|
||||||
|
@ -115,6 +55,35 @@ func isValidValue(s string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type readCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializes io.Reader with limited size and detects Content-Type from it.
|
||||||
|
// Returns r's error directly. Also returns processed data.
|
||||||
|
func readContentType(maxSize uint64, rInit func(uint64) (io.Reader, error)) (string, []byte, error) {
|
||||||
|
if maxSize > sizeToDetectType {
|
||||||
|
maxSize = sizeToDetectType
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, maxSize) // maybe sync-pool the slice?
|
||||||
|
|
||||||
|
r, err := rInit(maxSize)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := r.Read(buf)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
|
||||||
|
return http.DetectContentType(buf), buf, err // to not lose io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) {
|
func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
@ -140,12 +109,9 @@ func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) {
|
||||||
dis = "attachment"
|
dis = "attachment"
|
||||||
}
|
}
|
||||||
|
|
||||||
readDetector := newDetector()
|
payloadSize := rObj.Header.PayloadSize()
|
||||||
readDetector.SetReader(rObj.Payload)
|
|
||||||
readDetector.Detect()
|
|
||||||
|
|
||||||
r.Response.SetBodyStream(readDetector.MultiReader(), int(rObj.Header.PayloadSize()))
|
r.Response.Header.Set(fasthttp.HeaderContentLength, strconv.FormatUint(payloadSize, 10))
|
||||||
r.Response.Header.Set(fasthttp.HeaderContentLength, strconv.FormatUint(rObj.Header.PayloadSize(), 10))
|
|
||||||
var contentType string
|
var contentType string
|
||||||
for _, attr := range rObj.Header.Attributes() {
|
for _, attr := range rObj.Header.Attributes() {
|
||||||
key := attr.Key()
|
key := attr.Key()
|
||||||
|
@ -179,17 +145,34 @@ func (r request) receiveFile(clnt pool.Object, objectAddress *address.Address) {
|
||||||
idsToResponse(&r.Response, &rObj.Header)
|
idsToResponse(&r.Response, &rObj.Header)
|
||||||
|
|
||||||
if len(contentType) == 0 {
|
if len(contentType) == 0 {
|
||||||
if readDetector.err != nil {
|
// determine the Content-Type from the payload head
|
||||||
r.log.Error("could not read object", zap.Error(err))
|
var payloadHead []byte
|
||||||
response.Error(r.RequestCtx, "could not read object", fasthttp.StatusBadRequest)
|
|
||||||
|
contentType, payloadHead, err = readContentType(payloadSize, func(uint64) (io.Reader, error) {
|
||||||
|
return rObj.Payload, nil
|
||||||
|
})
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
r.log.Error("could not detect Content-Type from payload", zap.Error(err))
|
||||||
|
response.Error(r.RequestCtx, "could not detect Content-Type from payload", fasthttp.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
readDetector.Wait()
|
|
||||||
contentType = readDetector.contentType
|
// reset payload reader since part of the data has been read
|
||||||
|
var r io.Reader = bytes.NewReader(payloadHead)
|
||||||
|
|
||||||
|
if err != io.EOF { // otherwise, we've already read full payload
|
||||||
|
r = io.MultiReader(r, rObj.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: we could do with io.Reader, but SetBodyStream below closes body stream
|
||||||
|
// if it implements io.Closer and that's useful for us.
|
||||||
|
rObj.Payload = readCloser{r, rObj.Payload}
|
||||||
}
|
}
|
||||||
r.SetContentType(contentType)
|
r.SetContentType(contentType)
|
||||||
|
|
||||||
r.Response.Header.Set(fasthttp.HeaderContentDisposition, dis+"; filename="+path.Base(filename))
|
r.Response.Header.Set(fasthttp.HeaderContentDisposition, dis+"; filename="+path.Base(filename))
|
||||||
|
|
||||||
|
r.Response.SetBodyStream(rObj.Payload, int(payloadSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
// systemBackwardTranslator is used to convert headers looking like '__NEOFS__ATTR_NAME' to 'Neofs-Attr-Name'.
|
// systemBackwardTranslator is used to convert headers looking like '__NEOFS__ATTR_NAME' to 'Neofs-Attr-Name'.
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// max bytes needed to detect content type according to http.DetectContentType docs.
|
||||||
const sizeToDetectType = 512
|
const sizeToDetectType = 512
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -67,27 +68,13 @@ func (r request) headObject(clnt pool.Object, objectAddress *address.Address) {
|
||||||
idsToResponse(&r.Response, obj)
|
idsToResponse(&r.Response, obj)
|
||||||
|
|
||||||
if len(contentType) == 0 {
|
if len(contentType) == 0 {
|
||||||
sz := obj.PayloadSize()
|
contentType, _, err = readContentType(obj.PayloadSize(), func(sz uint64) (io.Reader, error) {
|
||||||
if sz > sizeToDetectType {
|
return clnt.ObjectRange(r.RequestCtx, *objectAddress, 0, sz, bearerOpt)
|
||||||
sz = sizeToDetectType
|
})
|
||||||
}
|
if err != nil && err != io.EOF {
|
||||||
|
|
||||||
res, err := clnt.ObjectRange(r.RequestCtx, *objectAddress, 0, sz, bearerOpt)
|
|
||||||
if err != nil {
|
|
||||||
r.handleNeoFSErr(err, start)
|
r.handleNeoFSErr(err, start)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer res.Close()
|
|
||||||
|
|
||||||
data := make([]byte, sz) // sync-pool it?
|
|
||||||
|
|
||||||
_, err = io.ReadFull(res, data)
|
|
||||||
if err != nil {
|
|
||||||
r.handleNeoFSErr(err, start)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
contentType = http.DetectContentType(data)
|
|
||||||
}
|
}
|
||||||
r.SetContentType(contentType)
|
r.SetContentType(contentType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package downloader
|
package downloader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -10,40 +8,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReader(t *testing.T) {
|
|
||||||
data := []byte("test string")
|
|
||||||
err := fmt.Errorf("something wrong")
|
|
||||||
|
|
||||||
for _, tc := range []struct {
|
|
||||||
err error
|
|
||||||
buff []byte
|
|
||||||
}{
|
|
||||||
{err: nil, buff: make([]byte, len(data)+1)},
|
|
||||||
{err: nil, buff: make([]byte, len(data))},
|
|
||||||
{err: nil, buff: make([]byte, len(data)-1)},
|
|
||||||
{err: err, buff: make([]byte, len(data)+1)},
|
|
||||||
{err: err, buff: make([]byte, len(data))},
|
|
||||||
{err: err, buff: make([]byte, len(data)-1)},
|
|
||||||
} {
|
|
||||||
var res []byte
|
|
||||||
var err error
|
|
||||||
var n int
|
|
||||||
|
|
||||||
r := newReader(data, tc.err)
|
|
||||||
for err == nil {
|
|
||||||
n, err = r.Read(tc.buff)
|
|
||||||
res = append(res, tc.buff[:n]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.err == nil {
|
|
||||||
require.Equal(t, io.EOF, err)
|
|
||||||
} else {
|
|
||||||
require.Equal(t, tc.err, err)
|
|
||||||
}
|
|
||||||
require.Equal(t, data, res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDetector(t *testing.T) {
|
func TestDetector(t *testing.T) {
|
||||||
txtContentType := "text/plain; charset=utf-8"
|
txtContentType := "text/plain; charset=utf-8"
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
|
@ -68,19 +32,15 @@ func TestDetector(t *testing.T) {
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
detector := newDetector()
|
contentType, data, err := readContentType(uint64(len(tc.Expected)),
|
||||||
|
func(sz uint64) (io.Reader, error) {
|
||||||
|
return strings.NewReader(tc.Expected), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
go func() {
|
|
||||||
detector.SetReader(bytes.NewBufferString(tc.Expected))
|
|
||||||
detector.Detect()
|
|
||||||
}()
|
|
||||||
|
|
||||||
detector.Wait()
|
|
||||||
require.Equal(t, tc.ContentType, detector.contentType)
|
|
||||||
|
|
||||||
data, err := io.ReadAll(detector.MultiReader())
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tc.Expected, string(data))
|
require.Equal(t, tc.ContentType, contentType)
|
||||||
|
require.True(t, strings.HasPrefix(tc.Expected, string(data)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue