Merge pull request #182 from masterSplinter01/feature/180-add-ct-check

Add check of continuation token
This commit is contained in:
Alex Vanin 2021-08-04 14:22:17 +03:00 committed by GitHub
commit 6674e350cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 91 additions and 37 deletions

View file

@ -2,9 +2,11 @@ package handler
import (
"net/http"
"net/url"
"strconv"
"time"
"github.com/nspcc-dev/neofs-api-go/pkg/object"
"github.com/nspcc-dev/neofs-s3-gw/api"
"github.com/nspcc-dev/neofs-s3-gw/api/layer"
)
@ -37,9 +39,8 @@ func encodeV1(p *layer.ListObjectsParamsV1, list *layer.ListObjectsInfoV1) *List
Prefix: p.Prefix,
MaxKeys: p.MaxKeys,
Delimiter: p.Delimiter,
IsTruncated: list.IsTruncated,
NextMarker: list.NextMarker,
IsTruncated: list.IsTruncated,
NextMarker: list.NextMarker,
}
res.CommonPrefixes = fillPrefixes(list.Prefixes, p.Encode)
@ -71,16 +72,14 @@ func (h *handler) ListObjectsV2Handler(w http.ResponseWriter, r *http.Request) {
func encodeV2(p *layer.ListObjectsParamsV2, list *layer.ListObjectsInfoV2) *ListObjectsV2Response {
res := &ListObjectsV2Response{
Name: p.Bucket,
EncodingType: p.Encode,
Prefix: s3PathEncode(p.Prefix, p.Encode),
KeyCount: len(list.Objects) + len(list.Prefixes),
MaxKeys: p.MaxKeys,
Delimiter: s3PathEncode(p.Delimiter, p.Encode),
StartAfter: s3PathEncode(p.StartAfter, p.Encode),
IsTruncated: list.IsTruncated,
Name: p.Bucket,
EncodingType: p.Encode,
Prefix: s3PathEncode(p.Prefix, p.Encode),
KeyCount: len(list.Objects) + len(list.Prefixes),
MaxKeys: p.MaxKeys,
Delimiter: s3PathEncode(p.Delimiter, p.Encode),
StartAfter: s3PathEncode(p.StartAfter, p.Encode),
IsTruncated: list.IsTruncated,
ContinuationToken: p.ContinuationToken,
NextContinuationToken: list.NextContinuationToken,
}
@ -94,8 +93,9 @@ func encodeV2(p *layer.ListObjectsParamsV2, list *layer.ListObjectsInfoV2) *List
func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error) {
var (
err error
res layer.ListObjectsParamsV1
err error
res layer.ListObjectsParamsV1
queryValues = r.URL.Query()
)
common, err := parseListObjectArgs(r)
@ -104,15 +104,16 @@ func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error)
}
res.ListObjectsParamsCommon = *common
res.Marker = r.URL.Query().Get("marker")
res.Marker = queryValues.Get("marker")
return &res, nil
}
func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error) {
var (
err error
res layer.ListObjectsParamsV2
err error
res layer.ListObjectsParamsV2
queryValues = r.URL.Query()
)
common, err := parseListObjectArgs(r)
@ -121,36 +122,51 @@ func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error)
}
res.ListObjectsParamsCommon = *common
res.ContinuationToken = r.URL.Query().Get("continuation-token")
res.StartAfter = r.URL.Query().Get("start-after")
res.FetchOwner, _ = strconv.ParseBool(r.URL.Query().Get("fetch-owner"))
res.ContinuationToken, err = parseContinuationToken(queryValues)
if err != nil {
return nil, err
}
res.StartAfter = queryValues.Get("start-after")
res.FetchOwner, _ = strconv.ParseBool(queryValues.Get("fetch-owner"))
return &res, nil
}
func parseListObjectArgs(r *http.Request) (*layer.ListObjectsParamsCommon, error) {
var (
err error
res layer.ListObjectsParamsCommon
err error
res layer.ListObjectsParamsCommon
queryValues = r.URL.Query()
)
if info := api.GetReqInfo(r.Context()); info != nil {
res.Bucket = info.BucketName
}
res.Delimiter = r.URL.Query().Get("delimiter")
res.Encode = r.URL.Query().Get("encoding-type")
res.Delimiter = queryValues.Get("delimiter")
res.Encode = queryValues.Get("encoding-type")
if r.URL.Query().Get("max-keys") == "" {
if queryValues.Get("max-keys") == "" {
res.MaxKeys = maxObjectList
} else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys < 0 {
} else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys < 0 {
return nil, api.GetAPIError(api.ErrInvalidMaxKeys)
}
res.Prefix = r.URL.Query().Get("prefix")
res.Prefix = queryValues.Get("prefix")
return &res, nil
}
func parseContinuationToken(queryValues url.Values) (string, error) {
if val, ok := queryValues["continuation-token"]; ok {
if err := object.NewID().Parse(val[0]); err != nil {
return "", api.GetAPIError(api.ErrIncorrectContinuationToken)
}
return val[0], nil
}
return "", nil
}
func fillPrefixes(src []string, encode string) []CommonPrefix {
var dst []CommonPrefix
for _, obj := range src {
@ -208,21 +224,22 @@ func (h *handler) ListBucketObjectVersionsHandler(w http.ResponseWriter, r *http
func parseListObjectVersionsRequest(r *http.Request) (*layer.ListObjectVersionsParams, error) {
var (
err error
res layer.ListObjectVersionsParams
err error
res layer.ListObjectVersionsParams
queryValues = r.URL.Query()
)
if r.URL.Query().Get("max-keys") == "" {
if queryValues.Get("max-keys") == "" {
res.MaxKeys = maxObjectList
} else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys <= 0 {
} else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys <= 0 {
return nil, api.GetAPIError(api.ErrInvalidMaxKeys)
}
res.Prefix = r.URL.Query().Get("prefix")
res.KeyMarker = r.URL.Query().Get("marker")
res.Delimiter = r.URL.Query().Get("delimiter")
res.Encode = r.URL.Query().Get("encoding-type")
res.VersionIDMarker = r.URL.Query().Get("version-id-marker")
res.Prefix = queryValues.Get("prefix")
res.KeyMarker = queryValues.Get("marker")
res.Delimiter = queryValues.Get("delimiter")
res.Encode = queryValues.Get("encoding-type")
res.VersionIDMarker = queryValues.Get("version-id-marker")
if info := api.GetReqInfo(r.Context()); info != nil {
res.Bucket = info.BucketName

View file

@ -0,0 +1,37 @@
package handler
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseContinuationToken(t *testing.T) {
var err error
t.Run("empty token", func(t *testing.T) {
var queryValues = map[string][]string{
"continuation-token": {""},
}
_, err = parseContinuationToken(queryValues)
require.Error(t, err)
})
t.Run("invalid not empty token", func(t *testing.T) {
var queryValues = map[string][]string{
"continuation-token": {"asd"},
}
_, err = parseContinuationToken(queryValues)
require.Error(t, err)
})
t.Run("valid token", func(t *testing.T) {
tokenStr := "75BTT5Z9o79XuKdUeGqvQbqDnxu6qWcR5EhxW8BXFf8t"
var queryValues = map[string][]string{
"continuation-token": {tokenStr},
}
token, err := parseContinuationToken(queryValues)
require.NoError(t, err)
require.Equal(t, tokenStr, token)
})
}