diff --git a/api/handler/object_list.go b/api/handler/object_list.go index 359e45c..a4b58ce 100644 --- a/api/handler/object_list.go +++ b/api/handler/object_list.go @@ -2,9 +2,11 @@ package handler import ( "net/http" + "net/url" "strconv" "time" + "github.com/nspcc-dev/neofs-api-go/pkg/object" "github.com/nspcc-dev/neofs-s3-gw/api" "github.com/nspcc-dev/neofs-s3-gw/api/layer" ) @@ -37,9 +39,8 @@ func encodeV1(p *layer.ListObjectsParamsV1, list *layer.ListObjectsInfoV1) *List Prefix: p.Prefix, MaxKeys: p.MaxKeys, Delimiter: p.Delimiter, - - IsTruncated: list.IsTruncated, - NextMarker: list.NextMarker, + IsTruncated: list.IsTruncated, + NextMarker: list.NextMarker, } res.CommonPrefixes = fillPrefixes(list.Prefixes, p.Encode) @@ -71,16 +72,14 @@ func (h *handler) ListObjectsV2Handler(w http.ResponseWriter, r *http.Request) { func encodeV2(p *layer.ListObjectsParamsV2, list *layer.ListObjectsInfoV2) *ListObjectsV2Response { res := &ListObjectsV2Response{ - Name: p.Bucket, - EncodingType: p.Encode, - Prefix: s3PathEncode(p.Prefix, p.Encode), - KeyCount: len(list.Objects) + len(list.Prefixes), - MaxKeys: p.MaxKeys, - Delimiter: s3PathEncode(p.Delimiter, p.Encode), - StartAfter: s3PathEncode(p.StartAfter, p.Encode), - - IsTruncated: list.IsTruncated, - + Name: p.Bucket, + EncodingType: p.Encode, + Prefix: s3PathEncode(p.Prefix, p.Encode), + KeyCount: len(list.Objects) + len(list.Prefixes), + MaxKeys: p.MaxKeys, + Delimiter: s3PathEncode(p.Delimiter, p.Encode), + StartAfter: s3PathEncode(p.StartAfter, p.Encode), + IsTruncated: list.IsTruncated, ContinuationToken: p.ContinuationToken, NextContinuationToken: list.NextContinuationToken, } @@ -94,8 +93,9 @@ func encodeV2(p *layer.ListObjectsParamsV2, list *layer.ListObjectsInfoV2) *List func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error) { var ( - err error - res layer.ListObjectsParamsV1 + err error + res layer.ListObjectsParamsV1 + queryValues = r.URL.Query() ) common, err := parseListObjectArgs(r) @@ -104,15 +104,16 @@ func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error) } res.ListObjectsParamsCommon = *common - res.Marker = r.URL.Query().Get("marker") + res.Marker = queryValues.Get("marker") return &res, nil } func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error) { var ( - err error - res layer.ListObjectsParamsV2 + err error + res layer.ListObjectsParamsV2 + queryValues = r.URL.Query() ) common, err := parseListObjectArgs(r) @@ -121,36 +122,51 @@ func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error) } res.ListObjectsParamsCommon = *common - res.ContinuationToken = r.URL.Query().Get("continuation-token") - res.StartAfter = r.URL.Query().Get("start-after") - res.FetchOwner, _ = strconv.ParseBool(r.URL.Query().Get("fetch-owner")) + res.ContinuationToken, err = parseContinuationToken(queryValues) + if err != nil { + return nil, err + } + + res.StartAfter = queryValues.Get("start-after") + res.FetchOwner, _ = strconv.ParseBool(queryValues.Get("fetch-owner")) return &res, nil } func parseListObjectArgs(r *http.Request) (*layer.ListObjectsParamsCommon, error) { var ( - err error - res layer.ListObjectsParamsCommon + err error + res layer.ListObjectsParamsCommon + queryValues = r.URL.Query() ) if info := api.GetReqInfo(r.Context()); info != nil { res.Bucket = info.BucketName } - res.Delimiter = r.URL.Query().Get("delimiter") - res.Encode = r.URL.Query().Get("encoding-type") + res.Delimiter = queryValues.Get("delimiter") + res.Encode = queryValues.Get("encoding-type") - if r.URL.Query().Get("max-keys") == "" { + if queryValues.Get("max-keys") == "" { res.MaxKeys = maxObjectList - } else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys < 0 { + } else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys < 0 { return nil, api.GetAPIError(api.ErrInvalidMaxKeys) } - res.Prefix = r.URL.Query().Get("prefix") + res.Prefix = queryValues.Get("prefix") return &res, nil } +func parseContinuationToken(queryValues url.Values) (string, error) { + if val, ok := queryValues["continuation-token"]; ok { + if err := object.NewID().Parse(val[0]); err != nil { + return "", api.GetAPIError(api.ErrIncorrectContinuationToken) + } + return val[0], nil + } + return "", nil +} + func fillPrefixes(src []string, encode string) []CommonPrefix { var dst []CommonPrefix for _, obj := range src { @@ -208,21 +224,22 @@ func (h *handler) ListBucketObjectVersionsHandler(w http.ResponseWriter, r *http func parseListObjectVersionsRequest(r *http.Request) (*layer.ListObjectVersionsParams, error) { var ( - err error - res layer.ListObjectVersionsParams + err error + res layer.ListObjectVersionsParams + queryValues = r.URL.Query() ) - if r.URL.Query().Get("max-keys") == "" { + if queryValues.Get("max-keys") == "" { res.MaxKeys = maxObjectList - } else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys <= 0 { + } else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys <= 0 { return nil, api.GetAPIError(api.ErrInvalidMaxKeys) } - res.Prefix = r.URL.Query().Get("prefix") - res.KeyMarker = r.URL.Query().Get("marker") - res.Delimiter = r.URL.Query().Get("delimiter") - res.Encode = r.URL.Query().Get("encoding-type") - res.VersionIDMarker = r.URL.Query().Get("version-id-marker") + res.Prefix = queryValues.Get("prefix") + res.KeyMarker = queryValues.Get("marker") + res.Delimiter = queryValues.Get("delimiter") + res.Encode = queryValues.Get("encoding-type") + res.VersionIDMarker = queryValues.Get("version-id-marker") if info := api.GetReqInfo(r.Context()); info != nil { res.Bucket = info.BucketName diff --git a/api/handler/object_list_test.go b/api/handler/object_list_test.go new file mode 100644 index 0000000..7b6a2f8 --- /dev/null +++ b/api/handler/object_list_test.go @@ -0,0 +1,37 @@ +package handler + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseContinuationToken(t *testing.T) { + var err error + + t.Run("empty token", func(t *testing.T) { + var queryValues = map[string][]string{ + "continuation-token": {""}, + } + _, err = parseContinuationToken(queryValues) + require.Error(t, err) + }) + + t.Run("invalid not empty token", func(t *testing.T) { + var queryValues = map[string][]string{ + "continuation-token": {"asd"}, + } + _, err = parseContinuationToken(queryValues) + require.Error(t, err) + }) + + t.Run("valid token", func(t *testing.T) { + tokenStr := "75BTT5Z9o79XuKdUeGqvQbqDnxu6qWcR5EhxW8BXFf8t" + var queryValues = map[string][]string{ + "continuation-token": {tokenStr}, + } + token, err := parseContinuationToken(queryValues) + require.NoError(t, err) + require.Equal(t, tokenStr, token) + }) +}