Merge pull request #182 from masterSplinter01/feature/180-add-ct-check

Add check of continuation token
This commit is contained in:
Alex Vanin 2021-08-04 14:22:17 +03:00 committed by GitHub
commit 6674e350cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 91 additions and 37 deletions

View file

@ -2,9 +2,11 @@ package handler
import ( import (
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"time" "time"
"github.com/nspcc-dev/neofs-api-go/pkg/object"
"github.com/nspcc-dev/neofs-s3-gw/api" "github.com/nspcc-dev/neofs-s3-gw/api"
"github.com/nspcc-dev/neofs-s3-gw/api/layer" "github.com/nspcc-dev/neofs-s3-gw/api/layer"
) )
@ -37,7 +39,6 @@ func encodeV1(p *layer.ListObjectsParamsV1, list *layer.ListObjectsInfoV1) *List
Prefix: p.Prefix, Prefix: p.Prefix,
MaxKeys: p.MaxKeys, MaxKeys: p.MaxKeys,
Delimiter: p.Delimiter, Delimiter: p.Delimiter,
IsTruncated: list.IsTruncated, IsTruncated: list.IsTruncated,
NextMarker: list.NextMarker, NextMarker: list.NextMarker,
} }
@ -78,9 +79,7 @@ func encodeV2(p *layer.ListObjectsParamsV2, list *layer.ListObjectsInfoV2) *List
MaxKeys: p.MaxKeys, MaxKeys: p.MaxKeys,
Delimiter: s3PathEncode(p.Delimiter, p.Encode), Delimiter: s3PathEncode(p.Delimiter, p.Encode),
StartAfter: s3PathEncode(p.StartAfter, p.Encode), StartAfter: s3PathEncode(p.StartAfter, p.Encode),
IsTruncated: list.IsTruncated, IsTruncated: list.IsTruncated,
ContinuationToken: p.ContinuationToken, ContinuationToken: p.ContinuationToken,
NextContinuationToken: list.NextContinuationToken, NextContinuationToken: list.NextContinuationToken,
} }
@ -96,6 +95,7 @@ func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error)
var ( var (
err error err error
res layer.ListObjectsParamsV1 res layer.ListObjectsParamsV1
queryValues = r.URL.Query()
) )
common, err := parseListObjectArgs(r) common, err := parseListObjectArgs(r)
@ -104,7 +104,7 @@ func parseListObjectsArgsV1(r *http.Request) (*layer.ListObjectsParamsV1, error)
} }
res.ListObjectsParamsCommon = *common res.ListObjectsParamsCommon = *common
res.Marker = r.URL.Query().Get("marker") res.Marker = queryValues.Get("marker")
return &res, nil return &res, nil
} }
@ -113,6 +113,7 @@ func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error)
var ( var (
err error err error
res layer.ListObjectsParamsV2 res layer.ListObjectsParamsV2
queryValues = r.URL.Query()
) )
common, err := parseListObjectArgs(r) common, err := parseListObjectArgs(r)
@ -121,9 +122,13 @@ func parseListObjectsArgsV2(r *http.Request) (*layer.ListObjectsParamsV2, error)
} }
res.ListObjectsParamsCommon = *common res.ListObjectsParamsCommon = *common
res.ContinuationToken = r.URL.Query().Get("continuation-token") res.ContinuationToken, err = parseContinuationToken(queryValues)
res.StartAfter = r.URL.Query().Get("start-after") if err != nil {
res.FetchOwner, _ = strconv.ParseBool(r.URL.Query().Get("fetch-owner")) return nil, err
}
res.StartAfter = queryValues.Get("start-after")
res.FetchOwner, _ = strconv.ParseBool(queryValues.Get("fetch-owner"))
return &res, nil return &res, nil
} }
@ -131,26 +136,37 @@ func parseListObjectArgs(r *http.Request) (*layer.ListObjectsParamsCommon, error
var ( var (
err error err error
res layer.ListObjectsParamsCommon res layer.ListObjectsParamsCommon
queryValues = r.URL.Query()
) )
if info := api.GetReqInfo(r.Context()); info != nil { if info := api.GetReqInfo(r.Context()); info != nil {
res.Bucket = info.BucketName res.Bucket = info.BucketName
} }
res.Delimiter = r.URL.Query().Get("delimiter") res.Delimiter = queryValues.Get("delimiter")
res.Encode = r.URL.Query().Get("encoding-type") res.Encode = queryValues.Get("encoding-type")
if r.URL.Query().Get("max-keys") == "" { if queryValues.Get("max-keys") == "" {
res.MaxKeys = maxObjectList res.MaxKeys = maxObjectList
} else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys < 0 { } else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys < 0 {
return nil, api.GetAPIError(api.ErrInvalidMaxKeys) return nil, api.GetAPIError(api.ErrInvalidMaxKeys)
} }
res.Prefix = r.URL.Query().Get("prefix") res.Prefix = queryValues.Get("prefix")
return &res, nil return &res, nil
} }
func parseContinuationToken(queryValues url.Values) (string, error) {
if val, ok := queryValues["continuation-token"]; ok {
if err := object.NewID().Parse(val[0]); err != nil {
return "", api.GetAPIError(api.ErrIncorrectContinuationToken)
}
return val[0], nil
}
return "", nil
}
func fillPrefixes(src []string, encode string) []CommonPrefix { func fillPrefixes(src []string, encode string) []CommonPrefix {
var dst []CommonPrefix var dst []CommonPrefix
for _, obj := range src { for _, obj := range src {
@ -210,19 +226,20 @@ func parseListObjectVersionsRequest(r *http.Request) (*layer.ListObjectVersionsP
var ( var (
err error err error
res layer.ListObjectVersionsParams res layer.ListObjectVersionsParams
queryValues = r.URL.Query()
) )
if r.URL.Query().Get("max-keys") == "" { if queryValues.Get("max-keys") == "" {
res.MaxKeys = maxObjectList res.MaxKeys = maxObjectList
} else if res.MaxKeys, err = strconv.Atoi(r.URL.Query().Get("max-keys")); err != nil || res.MaxKeys <= 0 { } else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys <= 0 {
return nil, api.GetAPIError(api.ErrInvalidMaxKeys) return nil, api.GetAPIError(api.ErrInvalidMaxKeys)
} }
res.Prefix = r.URL.Query().Get("prefix") res.Prefix = queryValues.Get("prefix")
res.KeyMarker = r.URL.Query().Get("marker") res.KeyMarker = queryValues.Get("marker")
res.Delimiter = r.URL.Query().Get("delimiter") res.Delimiter = queryValues.Get("delimiter")
res.Encode = r.URL.Query().Get("encoding-type") res.Encode = queryValues.Get("encoding-type")
res.VersionIDMarker = r.URL.Query().Get("version-id-marker") res.VersionIDMarker = queryValues.Get("version-id-marker")
if info := api.GetReqInfo(r.Context()); info != nil { if info := api.GetReqInfo(r.Context()); info != nil {
res.Bucket = info.BucketName res.Bucket = info.BucketName

View file

@ -0,0 +1,37 @@
package handler
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseContinuationToken(t *testing.T) {
var err error
t.Run("empty token", func(t *testing.T) {
var queryValues = map[string][]string{
"continuation-token": {""},
}
_, err = parseContinuationToken(queryValues)
require.Error(t, err)
})
t.Run("invalid not empty token", func(t *testing.T) {
var queryValues = map[string][]string{
"continuation-token": {"asd"},
}
_, err = parseContinuationToken(queryValues)
require.Error(t, err)
})
t.Run("valid token", func(t *testing.T) {
tokenStr := "75BTT5Z9o79XuKdUeGqvQbqDnxu6qWcR5EhxW8BXFf8t"
var queryValues = map[string][]string{
"continuation-token": {tokenStr},
}
token, err := parseContinuationToken(queryValues)
require.NoError(t, err)
require.Equal(t, tokenStr, token)
})
}