Compare commits

..

6 commits

Author SHA1 Message Date
85a81d2f79 [#339] lint: Ignore aws sdk dirs
Some checks failed
/ DCO (pull_request) Successful in 1m8s
/ Vulncheck (pull_request) Successful in 1m11s
/ Builds (pull_request) Failing after 1m9s
/ Lint (pull_request) Successful in 2m24s
/ Tests (pull_request) Successful in 1m25s
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
be2109bfc9 [#339] Presign fix aws sdk
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
0397bf83e7 [#339] Drop aws-sdk-go v1
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
81a1f9a959 [#339] Don't include additional content-length header for signing
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
f09229c76e [#339] sigv4a: Support presign
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
a692cb87fb [#339] Add aws-sdk-go-v2
Signed-off-by: Denis Kirillov <d.kirillov@yadro.com>
2024-09-05 15:48:33 +03:00
185 changed files with 7430 additions and 6141 deletions

View file

@ -3,12 +3,11 @@ FROM golang:1.22 AS builder
ARG BUILD=now
ARG REPO=git.frostfs.info/TrueCloudLab/frostfs-s3-gw
ARG VERSION=dev
ARG GOFLAGS=""
WORKDIR /src
COPY . /src
RUN make GOFLAGS=${GOFLAGS}
RUN make
# Executable image
FROM alpine AS frostfs-s3-gw

View file

@ -66,3 +66,6 @@ issues:
- EXC0003 # test/Test ... consider calling this
- EXC0004 # govet
- EXC0005 # C-style breaks
exclude-dirs:
- api/auth/signer/v4asdk2
- api/auth/signer/v4sdk2

View file

@ -8,13 +8,9 @@ This document outlines major changes between releases.
- Add support for virtual hosted style addressing (#446, #449)
- Support new param `frostfs.graceful_close_on_switch_timeout` (#475)
- Support patch object method (#479)
- Add `sign` command to `frostfs-s3-authmate` (#467)
- Support custom aws credentials (#509)
- Multinet dial support (#521)
### Changed
- Update go version to go1.19 (#470)
- Avoid maintenance mode storage node during object operations (#524)
## [0.30.0] - Kangshung -2024-07-19

View file

@ -14,8 +14,6 @@ METRICS_DUMP_OUT ?= ./metrics-dump.json
CMDS = $(addprefix frostfs-, $(notdir $(wildcard cmd/*)))
BINS = $(addprefix $(BINDIR)/, $(CMDS))
GOFLAGS ?=
# Variables for docker
REPO_BASENAME = $(shell basename `go list -m`)
HUB_IMAGE ?= "truecloudlab/$(REPO_BASENAME)"
@ -46,7 +44,6 @@ all: $(BINS)
$(BINS): $(BINDIR) dep
@echo "⇒ Build $@"
CGO_ENABLED=0 \
GOFLAGS=$(GOFLAGS) \
go build -v -trimpath \
-ldflags "-X $(REPO)/internal/version.Version=$(VERSION)" \
-o $@ ./cmd/$(subst frostfs-,,$(notdir $@))
@ -73,7 +70,7 @@ docker/%:
-w /src \
-u `stat -c "%u:%g" .` \
--env HOME=/src \
golang:$(GO_VERSION) make GOFLAGS=$(GOFLAGS) $*,\
golang:$(GO_VERSION) make $*,\
@echo "supported docker targets: all $(BINS) lint")
# Run tests
@ -124,7 +121,6 @@ image:
@docker build \
--build-arg REPO=$(REPO) \
--build-arg VERSION=$(VERSION) \
--build-arg GOFLAGS=$(GOFLAGS) \
--rm \
-f .docker/Dockerfile \
-t $(HUB_IMAGE):$(HUB_TAG) .

View file

@ -9,37 +9,40 @@ import (
"io"
"mime/multipart"
"net/http"
"net/url"
"regexp"
"strings"
"time"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
v4a "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/v4"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/tokens"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
)
// AuthorizationFieldRegexp -- is regexp for credentials with Base58 encoded cid and oid and '0' (zero) as delimiter.
var AuthorizationFieldRegexp = regexp.MustCompile(`AWS4-HMAC-SHA256 Credential=(?P<access_key_id>[^/]+)/(?P<date>[^/]+)/(?P<region>[^/]*)/(?P<service>[^/]+)/aws4_request,\s*SignedHeaders=(?P<signed_header_fields>.+),\s*Signature=(?P<v4_signature>.+)`)
var (
// authorizationFieldRegexp -- is regexp for credentials with Base58 encoded cid and oid and '0' (zero) as delimiter.
authorizationFieldRegexp = regexp.MustCompile(`AWS4-HMAC-SHA256 Credential=(?P<access_key_id>[^/]+)/(?P<date>[^/]+)/(?P<region>[^/]*)/(?P<service>[^/]+)/aws4_request,\s*SignedHeaders=(?P<signed_header_fields>.+),\s*Signature=(?P<v4_signature>.+)`)
// authorizationFieldV4aRegexp -- is regexp for credentials with Base58 encoded cid and oid and '0' (zero) as delimiter.
authorizationFieldV4aRegexp = regexp.MustCompile(`AWS4-ECDSA-P256-SHA256 Credential=(?P<access_key_id>[^/]+)/(?P<date>[^/]+)/(?P<service>[^/]+)/aws4_request,\s*SignedHeaders=(?P<signed_header_fields>.+),\s*Signature=(?P<v4_signature>.+)`)
// postPolicyCredentialRegexp -- is regexp for credentials when uploading file using POST with policy.
var postPolicyCredentialRegexp = regexp.MustCompile(`(?P<access_key_id>[^/]+)/(?P<date>[^/]+)/(?P<region>[^/]*)/(?P<service>[^/]+)/aws4_request`)
postPolicyCredentialRegexp = regexp.MustCompile(`(?P<access_key_id>[^/]+)/(?P<date>[^/]+)/(?P<region>[^/]*)/(?P<service>[^/]+)/aws4_request`)
)
type (
Center struct {
reg *RegexpSubmatcher
regV4a *RegexpSubmatcher
postReg *RegexpSubmatcher
cli tokens.Credentials
allowedAccessKeyIDPrefixes []string // empty slice means all access key ids are allowed
settings CenterSettings
}
CenterSettings interface {
AccessBoxContainer() (cid.ID, bool)
}
//nolint:revive
@ -47,22 +50,27 @@ type (
AccessKeyID string
Service string
Region string
SignatureV4 string
Signature string
SignedFields []string
Date string
IsPresigned bool
Expiration time.Duration
Preamble string
PayloadHash string
}
)
const (
accessKeyPartsNum = 2
authHeaderPartsNum = 6
authHeaderV4aPartsNum = 5
maxFormSizeMemory = 50 * 1048576 // 50 MB
AmzAlgorithm = "X-Amz-Algorithm"
AmzCredential = "X-Amz-Credential"
AmzSignature = "X-Amz-Signature"
AmzSignedHeaders = "X-Amz-SignedHeaders"
AmzRegionSet = "X-Amz-Region-Set"
AmzExpires = "X-Amz-Expires"
AmzDate = "X-Amz-Date"
AmzContentSHA256 = "X-Amz-Content-Sha256"
@ -87,34 +95,76 @@ var ContentSHA256HeaderStandardValue = map[string]struct{}{
}
// New creates an instance of AuthCenter.
func New(creds tokens.Credentials, prefixes []string, settings CenterSettings) *Center {
func New(creds tokens.Credentials, prefixes []string) *Center {
return &Center{
cli: creds,
reg: NewRegexpMatcher(AuthorizationFieldRegexp),
reg: NewRegexpMatcher(authorizationFieldRegexp),
regV4a: NewRegexpMatcher(authorizationFieldV4aRegexp),
postReg: NewRegexpMatcher(postPolicyCredentialRegexp),
allowedAccessKeyIDPrefixes: prefixes,
settings: settings,
}
}
func (c *Center) parseAuthHeader(header string) (*AuthHeader, error) {
submatches := c.reg.GetSubmatches(header)
const (
signaturePreambleSigV4 = "AWS4-HMAC-SHA256"
signaturePreambleSigV4A = "AWS4-ECDSA-P256-SHA256"
)
func (c *Center) parseAuthHeader(authHeader string, headers http.Header) (*AuthHeader, error) {
preamble, _, _ := strings.Cut(authHeader, " ")
var (
submatches map[string]string
region string
)
switch preamble {
case signaturePreambleSigV4:
submatches = c.reg.GetSubmatches(authHeader)
if len(submatches) != authHeaderPartsNum {
return nil, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrAuthorizationHeaderMalformed), header)
return nil, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrAuthorizationHeaderMalformed), authHeader)
}
region = submatches["region"]
case signaturePreambleSigV4A:
submatches = c.regV4a.GetSubmatches(authHeader)
if len(submatches) != authHeaderV4aPartsNum {
return nil, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrAuthorizationHeaderMalformed), authHeader)
}
region = headers.Get(AmzRegionSet)
default:
return nil, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrAuthorizationHeaderMalformed), authHeader)
}
// AWS4-ECDSA-P256-SHA256
// Credential=2XGRML5EW3LMHdf64W2DkBy1Nkuu4y4wGhUj44QjbXBi05ZNvs8WVwy1XTmSEkcVkydPKzCgtmR7U3zyLYTj3Snxf/20240326/s3/aws4_request,
// SignedHeaders=amz-sdk-invocation-id;amz-sdk-request;host;x-amz-content-sha256;x-amz-date;x-amz-region-set,
// Signature=3044022006a2bc760140834101d0a79667d6aa75768c1a28e9cafc8963484d0752a6c6050220629dc06d7d6505e1b1e2a5d1f974b25ba32fdffc6f3f70dc4dda31b8a6f7ea2b
signedFields := strings.Split(submatches["signed_header_fields"], ";")
accessKey := strings.Split(submatches["access_key_id"], "0")
if len(accessKey) != accessKeyPartsNum {
return nil, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrInvalidAccessKeyID), accessKey)
}
return &AuthHeader{
AccessKeyID: submatches["access_key_id"],
Service: submatches["service"],
Region: submatches["region"],
SignatureV4: submatches["v4_signature"],
SignedFields: signedFields,
Region: region,
Signature: submatches["v4_signature"],
SignedFields: strings.Split(submatches["signed_header_fields"], ";"),
Date: submatches["date"],
Preamble: preamble,
PayloadHash: headers.Get(AmzContentSHA256),
}, nil
}
func getAddress(accessKeyID string) (oid.Address, error) {
var addr oid.Address
if err := addr.DecodeString(strings.ReplaceAll(accessKeyID, "0", "/")); err != nil {
return addr, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrInvalidAccessKeyID), accessKeyID)
}
return addr, nil
}
func IsStandardContentSHA256(key string) bool {
_, ok := ContentSHA256HeaderStandardValue[key]
return ok
@ -129,7 +179,7 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
)
queryValues := r.URL.Query()
if queryValues.Get(AmzAlgorithm) == "AWS4-HMAC-SHA256" {
if queryValues.Get(AmzAlgorithm) == signaturePreambleSigV4 {
creds := strings.Split(queryValues.Get(AmzCredential), "/")
if len(creds) != 5 || creds[4] != "aws4_request" {
return nil, fmt.Errorf("bad X-Amz-Credential")
@ -138,14 +188,37 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
AccessKeyID: creds[0],
Service: creds[3],
Region: creds[2],
SignatureV4: queryValues.Get(AmzSignature),
Signature: queryValues.Get(AmzSignature),
SignedFields: strings.Split(queryValues.Get(AmzSignedHeaders), ";"),
Date: creds[1],
IsPresigned: true,
Preamble: signaturePreambleSigV4,
PayloadHash: r.Header.Get(AmzContentSHA256),
}
authHdr.Expiration, err = time.ParseDuration(queryValues.Get(AmzExpires) + "s")
if err != nil {
return nil, fmt.Errorf("couldn't parse X-Amz-Expires: %w", err)
return nil, fmt.Errorf("%w: couldn't parse X-Amz-Expires %v", apiErrors.GetAPIError(apiErrors.ErrMalformedExpires), err)
}
signatureDateTimeStr = queryValues.Get(AmzDate)
} else if queryValues.Get(AmzAlgorithm) == signaturePreambleSigV4A {
creds := strings.Split(queryValues.Get(AmzCredential), "/")
if len(creds) != 4 || creds[3] != "aws4_request" {
return nil, fmt.Errorf("bad X-Amz-Credential")
}
authHdr = &AuthHeader{
AccessKeyID: creds[0],
Service: creds[2],
Region: queryValues.Get(AmzRegionSet),
Signature: queryValues.Get(AmzSignature),
SignedFields: strings.Split(queryValues.Get(AmzSignedHeaders), ";"),
Date: creds[1],
IsPresigned: true,
Preamble: signaturePreambleSigV4A,
PayloadHash: r.Header.Get(AmzContentSHA256),
}
authHdr.Expiration, err = time.ParseDuration(queryValues.Get(AmzExpires) + "s")
if err != nil {
return nil, fmt.Errorf("%w: couldn't parse X-Amz-Expires %v", apiErrors.GetAPIError(apiErrors.ErrMalformedExpires), err)
}
signatureDateTimeStr = queryValues.Get(AmzDate)
} else {
@ -156,7 +229,7 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
}
return nil, fmt.Errorf("%w: %v", middleware.ErrNoAuthorizationHeader, authHeaderField)
}
authHdr, err = c.parseAuthHeader(authHeaderField[0])
authHdr, err = c.parseAuthHeader(authHeaderField[0], r.Header)
if err != nil {
return nil, err
}
@ -173,14 +246,14 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
return nil, err
}
cnrID, err := c.getAccessBoxContainer(authHdr.AccessKeyID)
addr, err := getAddress(authHdr.AccessKeyID)
if err != nil {
return nil, err
}
box, attrs, err := c.cli.GetBox(r.Context(), cnrID, authHdr.AccessKeyID)
box, attrs, err := c.cli.GetBox(r.Context(), addr)
if err != nil {
return nil, fmt.Errorf("get box by access key '%s': %w", authHdr.AccessKeyID, err)
return nil, fmt.Errorf("get box '%s': %w", addr, err)
}
if err = checkFormatHashContentSHA256(r.Header.Get(AmzContentSHA256)); err != nil {
@ -188,7 +261,7 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
}
clonedRequest := cloneRequest(r, authHdr)
if err = c.checkSign(authHdr, box, clonedRequest, signatureDateTime); err != nil {
if err = c.checkSign(r.Context(), authHdr, box, clonedRequest, signatureDateTime); err != nil {
return nil, err
}
@ -197,7 +270,7 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
AuthHeaders: &middleware.AuthHeader{
AccessKeyID: authHdr.AccessKeyID,
Region: authHdr.Region,
SignatureV4: authHdr.SignatureV4,
SignatureV4: authHdr.Signature,
},
Attributes: attrs,
}
@ -208,29 +281,15 @@ func (c *Center) Authenticate(r *http.Request) (*middleware.Box, error) {
return result, nil
}
func (c *Center) getAccessBoxContainer(accessKeyID string) (cid.ID, error) {
var addr oid.Address
if err := addr.DecodeString(strings.ReplaceAll(accessKeyID, "0", "/")); err == nil {
return addr.Container(), nil
}
cnrID, ok := c.settings.AccessBoxContainer()
if ok {
return cnrID, nil
}
return cid.ID{}, fmt.Errorf("%w: unknown container for creds '%s'", apierr.GetAPIError(apierr.ErrInvalidAccessKeyID), accessKeyID)
}
func checkFormatHashContentSHA256(hash string) error {
if !IsStandardContentSHA256(hash) {
hashBinary, err := hex.DecodeString(hash)
if err != nil {
return fmt.Errorf("%w: decode hash: %s: %s", apierr.GetAPIError(apierr.ErrContentSHA256Mismatch),
return fmt.Errorf("%w: decode hash: %s: %s", apiErrors.GetAPIError(apiErrors.ErrContentSHA256Mismatch),
hash, err.Error())
}
if len(hashBinary) != sha256.Size && len(hash) != 0 {
return fmt.Errorf("%w: invalid hash size %d", apierr.GetAPIError(apierr.ErrContentSHA256Mismatch), len(hashBinary))
return fmt.Errorf("%w: invalid hash size %d", apiErrors.GetAPIError(apiErrors.ErrContentSHA256Mismatch), len(hashBinary))
}
}
@ -248,12 +307,12 @@ func (c Center) checkAccessKeyID(accessKeyID string) error {
}
}
return fmt.Errorf("%w: accesskeyID prefix isn't allowed", apierr.GetAPIError(apierr.ErrAccessDenied))
return fmt.Errorf("%w: accesskeyID prefix isn't allowed", apiErrors.GetAPIError(apiErrors.ErrAccessDenied))
}
func (c *Center) checkFormData(r *http.Request) (*middleware.Box, error) {
if err := r.ParseMultipartForm(maxFormSizeMemory); err != nil {
return nil, fmt.Errorf("%w: parse multipart form with max size %d", apierr.GetAPIError(apierr.ErrInvalidArgument), maxFormSizeMemory)
return nil, fmt.Errorf("%w: parse multipart form with max size %d", apiErrors.GetAPIError(apiErrors.ErrInvalidArgument), maxFormSizeMemory)
}
if err := prepareForm(r.MultipartForm); err != nil {
@ -268,7 +327,7 @@ func (c *Center) checkFormData(r *http.Request) (*middleware.Box, error) {
creds := MultipartFormValue(r, "x-amz-credential")
submatches := c.postReg.GetSubmatches(creds)
if len(submatches) != 4 {
return nil, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrAuthorizationHeaderMalformed), creds)
return nil, fmt.Errorf("%w: %s", apiErrors.GetAPIError(apiErrors.ErrAuthorizationHeaderMalformed), creds)
}
signatureDateTime, err := time.Parse("20060102T150405Z", MultipartFormValue(r, "x-amz-date"))
@ -278,14 +337,14 @@ func (c *Center) checkFormData(r *http.Request) (*middleware.Box, error) {
accessKeyID := submatches["access_key_id"]
cnrID, err := c.getAccessBoxContainer(accessKeyID)
addr, err := getAddress(accessKeyID)
if err != nil {
return nil, err
}
box, attrs, err := c.cli.GetBox(r.Context(), cnrID, accessKeyID)
box, attrs, err := c.cli.GetBox(r.Context(), addr)
if err != nil {
return nil, fmt.Errorf("get box by accessKeyID '%s': %w", accessKeyID, err)
return nil, fmt.Errorf("get box '%s': %w", addr, err)
}
secret := box.Gate.SecretKey
@ -294,7 +353,7 @@ func (c *Center) checkFormData(r *http.Request) (*middleware.Box, error) {
signature := SignStr(secret, service, region, signatureDateTime, policy)
reqSignature := MultipartFormValue(r, "x-amz-signature")
if signature != reqSignature {
return nil, fmt.Errorf("%w: %s != %s", apierr.GetAPIError(apierr.ErrSignatureDoesNotMatch),
return nil, fmt.Errorf("%w: %s != %s", apiErrors.GetAPIError(apiErrors.ErrSignatureDoesNotMatch),
reqSignature, signature)
}
@ -330,38 +389,87 @@ func cloneRequest(r *http.Request, authHeader *AuthHeader) *http.Request {
return otherRequest
}
func (c *Center) checkSign(authHeader *AuthHeader, box *accessbox.Box, request *http.Request, signatureDateTime time.Time) error {
awsCreds := credentials.NewStaticCredentials(authHeader.AccessKeyID, box.Gate.SecretKey, "")
signer := v4.NewSigner(awsCreds)
signer.DisableURIPathEscaping = true
func (c *Center) checkSign(ctx context.Context, authHeader *AuthHeader, box *accessbox.Box, request *http.Request, signatureDateTime time.Time) error {
var signature string
switch authHeader.Preamble {
case signaturePreambleSigV4:
creds := aws.Credentials{
AccessKeyID: authHeader.AccessKeyID,
SecretAccessKey: box.Gate.SecretKey,
}
signer := v4.NewSigner(func(options *v4.SignerOptions) {
options.DisableURIPathEscaping = true
})
if authHeader.IsPresigned {
now := time.Now()
if signatureDateTime.Add(authHeader.Expiration).Before(now) {
return fmt.Errorf("%w: expired: now %s, signature %s", apierr.GetAPIError(apierr.ErrExpiredPresignRequest),
now.Format(time.RFC3339), signatureDateTime.Format(time.RFC3339))
if err := checkPresignedDate(authHeader, signatureDateTime); err != nil {
return err
}
if now.Before(signatureDateTime) {
return fmt.Errorf("%w: signature time from the future: now %s, signature %s", apierr.GetAPIError(apierr.ErrBadRequest),
now.Format(time.RFC3339), signatureDateTime.Format(time.RFC3339))
}
if _, err := signer.Presign(request, nil, authHeader.Service, authHeader.Region, authHeader.Expiration, signatureDateTime); err != nil {
signedURI, _, err := signer.PresignHTTP(ctx, creds, request, authHeader.PayloadHash, authHeader.Service, authHeader.Region, signatureDateTime)
if err != nil {
return fmt.Errorf("failed to pre-sign temporary HTTP request: %w", err)
}
signature = request.URL.Query().Get(AmzSignature)
u, err := url.ParseRequestURI(signedURI)
if err != nil {
return err
}
signature = u.Query().Get(AmzSignature)
} else {
if _, err := signer.Sign(request, nil, authHeader.Service, authHeader.Region, signatureDateTime); err != nil {
if err := signer.SignHTTP(ctx, creds, request, authHeader.PayloadHash, authHeader.Service, authHeader.Region, signatureDateTime); err != nil {
return fmt.Errorf("failed to sign temporary HTTP request: %w", err)
}
signature = c.reg.GetSubmatches(request.Header.Get(AuthorizationHdr))["v4_signature"]
}
if authHeader.SignatureV4 != signature {
return fmt.Errorf("%w: %s != %s: headers %v", apierr.GetAPIError(apierr.ErrSignatureDoesNotMatch),
authHeader.SignatureV4, signature, authHeader.SignedFields)
if authHeader.Signature != signature {
return fmt.Errorf("%w: %s != %s: headers %v", apiErrors.GetAPIError(apiErrors.ErrSignatureDoesNotMatch),
authHeader.Signature, signature, authHeader.SignedFields)
}
case signaturePreambleSigV4A:
signer := v4a.NewSigner(func(options *v4a.SignerOptions) {
options.DisableURIPathEscaping = true
})
credAdapter := v4a.SymmetricCredentialAdaptor{
SymmetricProvider: credentials.NewStaticCredentialsProvider(authHeader.AccessKeyID, box.Gate.SecretKey, ""),
}
creds, err := credAdapter.RetrievePrivateKey(request.Context())
if err != nil {
return fmt.Errorf("failed to derive assymetric key from credentials: %w", err)
}
if !authHeader.IsPresigned {
return signer.VerifySignature(creds, request, authHeader.PayloadHash, authHeader.Service,
strings.Split(authHeader.Region, ","), signatureDateTime, authHeader.Signature)
}
if err = checkPresignedDate(authHeader, signatureDateTime); err != nil {
return err
}
return signer.VerifyPresigned(creds, request, authHeader.PayloadHash, authHeader.Service,
strings.Split(authHeader.Region, ","), signatureDateTime, authHeader.Signature)
default:
return fmt.Errorf("invalid preamble: %s", authHeader.Preamble)
}
return nil
}
func checkPresignedDate(authHeader *AuthHeader, signatureDateTime time.Time) error {
now := time.Now()
if signatureDateTime.Add(authHeader.Expiration).Before(now) {
return fmt.Errorf("%w: expired: now %s, signature %s", apiErrors.GetAPIError(apiErrors.ErrExpiredPresignRequest),
now.Format(time.RFC3339), signatureDateTime.Format(time.RFC3339))
}
if now.Before(signatureDateTime) {
return fmt.Errorf("%w: signature time from the future: now %s, signature %s", apiErrors.GetAPIError(apiErrors.ErrBadRequest),
now.Format(time.RFC3339), signatureDateTime.Format(time.RFC3339))
}
return nil
}

View file

@ -72,7 +72,7 @@ func DoFuzzAuthenticate(input []byte) int {
c := &Center{
cli: mock,
reg: NewRegexpMatcher(AuthorizationFieldRegexp),
reg: NewRegexpMatcher(authorizationFieldRegexp),
postReg: NewRegexpMatcher(postPolicyCredentialRegexp),
}

View file

@ -8,44 +8,35 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
v4a "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/v4"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/cache"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/tokens"
frosterr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
frostfsErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/smithy-go/logging"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)
type centerSettingsMock struct {
accessBoxContainer *cid.ID
}
func (c *centerSettingsMock) AccessBoxContainer() (cid.ID, bool) {
if c.accessBoxContainer == nil {
return cid.ID{}, false
}
return *c.accessBoxContainer, true
}
func TestAuthHeaderParse(t *testing.T) {
defaultHeader := "AWS4-HMAC-SHA256 Credential=oid0cid/20210809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=2811ccb9e242f41426738fb1f"
center := &Center{
reg: NewRegexpMatcher(AuthorizationFieldRegexp),
settings: &centerSettingsMock{},
reg: NewRegexpMatcher(authorizationFieldRegexp),
}
for _, tc := range []struct {
@ -60,9 +51,10 @@ func TestAuthHeaderParse(t *testing.T) {
AccessKeyID: "oid0cid",
Service: "s3",
Region: "us-east-1",
SignatureV4: "2811ccb9e242f41426738fb1f",
Signature: "2811ccb9e242f41426738fb1f",
SignedFields: []string{"host", "x-amz-content-sha256", "x-amz-date"},
Date: "20210809",
Preamble: signaturePreambleSigV4,
},
},
{
@ -70,13 +62,55 @@ func TestAuthHeaderParse(t *testing.T) {
err: errors.GetAPIError(errors.ErrAuthorizationHeaderMalformed),
expected: nil,
},
{
header: strings.ReplaceAll(defaultHeader, "oid0cid", "oidcid"),
err: errors.GetAPIError(errors.ErrInvalidAccessKeyID),
expected: nil,
},
} {
authHeader, err := center.parseAuthHeader(tc.header)
authHeader, err := center.parseAuthHeader(tc.header, nil)
require.ErrorIs(t, err, tc.err, tc.header)
require.Equal(t, tc.expected, authHeader, tc.header)
}
}
func TestAuthHeaderGetAddress(t *testing.T) {
defaulErr := errors.GetAPIError(errors.ErrInvalidAccessKeyID)
for _, tc := range []struct {
authHeader *AuthHeader
err error
}{
{
authHeader: &AuthHeader{
AccessKeyID: "vWqF8cMDRbJcvnPLALoQGnABPPhw8NyYMcGsfDPfZJM0HrgjonN8CgFvCZ3kh9BUXw4W2tJ5E7EAGhueSF122HB",
},
err: nil,
},
{
authHeader: &AuthHeader{
AccessKeyID: "vWqF8cMDRbJcvnPLALoQGnABPPhw8NyYMcGsfDPfZJMHrgjonN8CgFvCZ3kh9BUXw4W2tJ5E7EAGhueSF122HB",
},
err: defaulErr,
},
{
authHeader: &AuthHeader{
AccessKeyID: "oid0cid",
},
err: defaulErr,
},
{
authHeader: &AuthHeader{
AccessKeyID: "oidcid",
},
err: defaulErr,
},
} {
_, err := getAddress(tc.authHeader.AccessKeyID)
require.ErrorIs(t, err, tc.err, tc.authHeader.AccessKeyID)
}
}
func TestSignature(t *testing.T) {
secret := "66be461c3cd429941c55daf42fad2b8153e5a2016ba89c9494d97677cc9d3872"
strToSign := "eyAiZXhwaXJhdGlvbiI6ICIyMDE1LTEyLTMwVDEyOjAwOjAwLjAwMFoiLAogICJjb25kaXRpb25zIjogWwogICAgeyJidWNrZXQiOiAiYWNsIn0sCiAgICBbInN0YXJ0cy13aXRoIiwgIiRrZXkiLCAidXNlci91c2VyMS8iXSwKICAgIHsic3VjY2Vzc19hY3Rpb25fcmVkaXJlY3QiOiAiaHR0cDovL2xvY2FsaG9zdDo4MDg0L2FjbCJ9LAogICAgWyJzdGFydHMtd2l0aCIsICIkQ29udGVudC1UeXBlIiwgImltYWdlLyJdLAogICAgeyJ4LWFtei1tZXRhLXV1aWQiOiAiMTQzNjUxMjM2NTEyNzQifSwKICAgIFsic3RhcnRzLXdpdGgiLCAiJHgtYW16LW1ldGEtdGFnIiwgIiJdLAoKICAgIHsiWC1BbXotQ3JlZGVudGlhbCI6ICI4Vmk0MVBIbjVGMXNzY2J4OUhqMXdmMUU2aERUYURpNndxOGhxTU05NllKdTA1QzVDeUVkVlFoV1E2aVZGekFpTkxXaTlFc3BiUTE5ZDRuR3pTYnZVZm10TS8yMDE1MTIyOS91cy1lYXN0LTEvczMvYXdzNF9yZXF1ZXN0In0sCiAgICB7IngtYW16LWFsZ29yaXRobSI6ICJBV1M0LUhNQUMtU0hBMjU2In0sCiAgICB7IlgtQW16LURhdGUiOiAiMjAxNTEyMjlUMDAwMDAwWiIgfSwKICAgIHsieC1pZ25vcmUtdG1wIjogInNvbWV0aGluZyIgfQogIF0KfQ=="
@ -90,6 +124,52 @@ func TestSignature(t *testing.T) {
require.Equal(t, "dfbe886241d9e369cf4b329ca0f15eb27306c97aa1022cc0bb5a914c4ef87634", signature)
}
func TestSignatureV4A(t *testing.T) {
accessKeyID := "2XEbqH4M3ym7a3E3esxfZ2gRLnMwDXrCN4y1SkQg5fHa09sThVmVL3EE6xeKsyMzaqu5jPi41YCaVbnwbwCTF3bx1"
secretKey := "00637f53f842573aaa06c2164c598973cd986880987111416cf71f1619def537"
signer := v4a.NewSigner(func(options *v4a.SignerOptions) {
options.DisableURIPathEscaping = true
options.Logger = logging.NewStandardLogger(os.Stdout)
options.LogSigning = true
})
credAdapter := v4a.SymmetricCredentialAdaptor{
SymmetricProvider: credentials.NewStaticCredentialsProvider(accessKeyID, secretKey, ""),
}
bodyStr := `
1b;chunk-signature=3045022100b63692a1b20759bdabd342011823427a8952df75c93174d98ad043abca8052e002201695228a91ba986171b8d0ad20856d3d94ca3614d0a90a50a531ba8e52447b9b**
Testing with the {sdk-java}
0;chunk-signature=30440220455885a2d4e9f705256ca6b0a5a22f7f784780ccbd1c0a371e5db3059c91745b022073259dd44746cbd63261d628a04d25be5a32a974c077c5c2d83c8157fb323b9f****
`
body := bytes.NewBufferString(bodyStr)
req, err := http.NewRequest("PUT", "http://localhost:8084/test/tmp", body)
require.NoError(t, err)
req.Header.Set("Amz-Sdk-Invocation-Id", "ca3a3cde-7d26-fce6-ed9c-82f7a0573824")
req.Header.Set("Amz-Sdk-Request", "attempt=2; max=2")
req.Header.Set("Authorization", "AWS4-ECDSA-P256-SHA256 Credential=2XEbqH4M3ym7a3E3esxfZ2gRLnMwDXrCN4y1SkQg5fHa09sThVmVL3EE6xeKsyMzaqu5jPi41YCaVbnwbwCTF3bx1/20240904/s3/aws4_request, SignedHeaders=amz-sdk-invocation-id;amz-sdk-request;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-decoded-content-length;x-amz-region-set, Signature=30440220574244c5ff5deba388c4e3b0541a42113179b6839b3e6b4212d255a118fa9089022056f7b9b72c93f67dbcd25fe9ca67950b5913fc00bb7a62bc276c21e828c0b6c7")
req.Header.Set("Content-Length", "360")
req.Header.Set("Content-Type", "text/plain; charset=UTF-8")
req.Header.Set("X-Amz-Content-Sha256", "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD")
req.Header.Set("X-Amz-Date", "20240904T133253Z")
req.Header.Set("X-Amz-Decoded-Content-Length", "27")
req.Header.Set("X-Amz-Region-Set", "us-east-1")
service := "s3"
regionSet := []string{"us-east-1"}
signature := "30440220574244c5ff5deba388c4e3b0541a42113179b6839b3e6b4212d255a118fa9089022056f7b9b72c93f67dbcd25fe9ca67950b5913fc00bb7a62bc276c21e828c0b6c7"
signingTime, err := time.Parse("20060102T150405Z", "20240904T133253Z")
require.NoError(t, err)
creds, err := credAdapter.RetrievePrivateKey(req.Context())
require.NoError(t, err)
err = signer.VerifySignature(creds, req, "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD", service, regionSet, signingTime, signature)
require.NoError(t, err)
}
func TestCheckFormatContentSHA256(t *testing.T) {
defaultErr := errors.GetAPIError(errors.ErrContentSHA256Mismatch)
@ -142,17 +222,17 @@ func TestCheckFormatContentSHA256(t *testing.T) {
}
type frostFSMock struct {
objects map[string]*object.Object
objects map[oid.Address]*object.Object
}
func newFrostFSMock() *frostFSMock {
return &frostFSMock{
objects: map[string]*object.Object{},
objects: map[oid.Address]*object.Object{},
}
}
func (f *frostFSMock) GetCredsObject(_ context.Context, prm tokens.PrmGetCredsObject) (*object.Object, error) {
obj, ok := f.objects[prm.AccessKeyID]
func (f *frostFSMock) GetCredsObject(_ context.Context, address oid.Address) (*object.Object, error) {
obj, ok := f.objects[address]
if !ok {
return nil, fmt.Errorf("not found")
}
@ -165,6 +245,7 @@ func (f *frostFSMock) CreateObject(context.Context, tokens.PrmObjectCreate) (oid
}
func TestAuthenticate(t *testing.T) {
ctx := context.Background()
key, err := keys.NewPrivateKey()
require.NoError(t, err)
@ -179,7 +260,7 @@ func TestAuthenticate(t *testing.T) {
GateKey: key.PublicKey(),
}}
accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret"), false)
accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret"))
require.NoError(t, err)
data, err := accessBox.Marshal()
require.NoError(t, err)
@ -190,13 +271,13 @@ func TestAuthenticate(t *testing.T) {
obj.SetContainerID(addr.Container())
obj.SetID(addr.Object())
accessKeyID := getAccessKeyID(addr)
frostfs := newFrostFSMock()
frostfs.objects[accessKeyID] = &obj
frostfs.objects[addr] = &obj
awsCreds := credentials.NewStaticCredentials(accessKeyID, secret.SecretKey, "")
defaultSigner := v4.NewSigner(awsCreds)
accessKeyID := addr.Container().String() + "0" + addr.Object().String()
awsCreds := aws.Credentials{AccessKeyID: accessKeyID, SecretAccessKey: secret.SecretKey}
defaultSigner := v4.NewSigner()
service, region := "s3", "default"
invalidValue := "invalid-value"
@ -219,7 +300,7 @@ func TestAuthenticate(t *testing.T) {
prefixes: []string{addr.Container().String()},
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Sign(r, nil, service, region, time.Now())
err = defaultSigner.SignHTTP(ctx, awsCreds, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -245,8 +326,8 @@ func TestAuthenticate(t *testing.T) {
name: "invalid access key id format",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
signer := v4.NewSigner(credentials.NewStaticCredentials(addr.Object().String(), secret.SecretKey, ""))
_, err = signer.Sign(r, nil, service, region, time.Now())
cred := aws.Credentials{AccessKeyID: addr.Object().String(), SecretAccessKey: secret.SecretKey}
err = v4.NewSigner().SignHTTP(ctx, cred, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -258,7 +339,7 @@ func TestAuthenticate(t *testing.T) {
prefixes: []string{addr.Object().String()},
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Sign(r, nil, service, region, time.Now())
err = defaultSigner.SignHTTP(ctx, awsCreds, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -269,8 +350,8 @@ func TestAuthenticate(t *testing.T) {
name: "invalid access key id value",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
signer := v4.NewSigner(credentials.NewStaticCredentials(accessKeyID[:len(accessKeyID)-4], secret.SecretKey, ""))
_, err = signer.Sign(r, nil, service, region, time.Now())
cred := aws.Credentials{AccessKeyID: accessKeyID[:len(accessKeyID)-4], SecretAccessKey: secret.SecretKey}
err = v4.NewSigner().SignHTTP(ctx, cred, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -281,8 +362,8 @@ func TestAuthenticate(t *testing.T) {
name: "unknown access key id",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
signer := v4.NewSigner(credentials.NewStaticCredentials(addr.Object().String()+"0"+addr.Container().String(), secret.SecretKey, ""))
_, err = signer.Sign(r, nil, service, region, time.Now())
cred := aws.Credentials{AccessKeyID: addr.Object().String() + "0" + addr.Container().String(), SecretAccessKey: secret.SecretKey}
err = v4.NewSigner().SignHTTP(ctx, cred, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -292,8 +373,8 @@ func TestAuthenticate(t *testing.T) {
name: "invalid signature",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
signer := v4.NewSigner(credentials.NewStaticCredentials(accessKeyID, "secret", ""))
_, err = signer.Sign(r, nil, service, region, time.Now())
cred := aws.Credentials{AccessKeyID: accessKeyID, SecretAccessKey: "secret"}
err = v4.NewSigner().SignHTTP(ctx, cred, r, "", service, region, time.Now())
require.NoError(t, err)
return r
}(),
@ -305,7 +386,7 @@ func TestAuthenticate(t *testing.T) {
prefixes: []string{addr.Container().String()},
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Sign(r, nil, service, region, time.Now())
err = defaultSigner.SignHTTP(ctx, awsCreds, r, "", service, region, time.Now())
r.Header.Set(AmzDate, invalidValue)
require.NoError(t, err)
return r
@ -317,7 +398,7 @@ func TestAuthenticate(t *testing.T) {
prefixes: []string{addr.Container().String()},
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Sign(r, nil, service, region, time.Now())
err = defaultSigner.SignHTTP(ctx, awsCreds, r, "", service, region, time.Now())
r.Header.Set(AmzContentSHA256, invalidValue)
require.NoError(t, err)
return r
@ -328,7 +409,10 @@ func TestAuthenticate(t *testing.T) {
name: "valid presign",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now())
r.Header.Set(AmzExpires, "60")
signedURI, _, err := defaultSigner.PresignHTTP(ctx, awsCreds, r, "", service, region, time.Now())
require.NoError(t, err)
r.URL, err = url.ParseRequestURI(signedURI)
require.NoError(t, err)
return r
}(),
@ -350,20 +434,24 @@ func TestAuthenticate(t *testing.T) {
name: "presign, bad X-Amz-Expires",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now())
queryParams := r.URL.Query()
queryParams.Set("X-Amz-Expires", invalidValue)
r.URL.RawQuery = queryParams.Encode()
r.Header.Set(AmzExpires, invalidValue)
signedURI, _, err := defaultSigner.PresignHTTP(ctx, awsCreds, r, UnsignedPayload, service, region, time.Now())
require.NoError(t, err)
r.URL, err = url.ParseRequestURI(signedURI)
require.NoError(t, err)
return r
}(),
err: true,
errCode: errors.ErrMalformedExpires,
},
{
name: "presign, expired",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now().Add(-time.Minute))
r.Header.Set(AmzExpires, "60")
signedURI, _, err := defaultSigner.PresignHTTP(ctx, awsCreds, r, UnsignedPayload, service, region, time.Now().Add(-time.Minute))
require.NoError(t, err)
r.URL, err = url.ParseRequestURI(signedURI)
require.NoError(t, err)
return r
}(),
@ -374,7 +462,10 @@ func TestAuthenticate(t *testing.T) {
name: "presign, signature from future",
request: func() *http.Request {
r := httptest.NewRequest(http.MethodPost, "/", nil)
_, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now().Add(time.Minute))
r.Header.Set(AmzExpires, "60")
signedURI, _, err := defaultSigner.PresignHTTP(ctx, awsCreds, r, UnsignedPayload, service, region, time.Now().Add(time.Minute))
require.NoError(t, err)
r.URL, err = url.ParseRequestURI(signedURI)
require.NoError(t, err)
return r
}(),
@ -384,13 +475,13 @@ func TestAuthenticate(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
creds := tokens.New(bigConfig)
cntr := New(creds, tc.prefixes, &centerSettingsMock{})
cntr := New(creds, tc.prefixes)
box, err := cntr.Authenticate(tc.request)
if tc.err {
require.Error(t, err)
if tc.errCode > 0 {
err = frosterr.UnwrapErr(err)
err = frostfsErrors.UnwrapErr(err)
require.Equal(t, errors.GetAPIError(tc.errCode), err)
}
} else {
@ -426,7 +517,7 @@ func TestHTTPPostAuthenticate(t *testing.T) {
GateKey: key.PublicKey(),
}}
accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret"), false)
accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret"))
require.NoError(t, err)
data, err := accessBox.Marshal()
require.NoError(t, err)
@ -437,11 +528,10 @@ func TestHTTPPostAuthenticate(t *testing.T) {
obj.SetContainerID(addr.Container())
obj.SetID(addr.Object())
accessKeyID := getAccessKeyID(addr)
frostfs := newFrostFSMock()
frostfs.objects[accessKeyID] = &obj
frostfs.objects[addr] = &obj
accessKeyID := addr.Container().String() + "0" + addr.Object().String()
invalidAccessKeyID := oidtest.Address().String() + "0" + oidtest.Address().Object().String()
timeToSign := time.Now()
@ -562,13 +652,13 @@ func TestHTTPPostAuthenticate(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
creds := tokens.New(bigConfig)
cntr := New(creds, tc.prefixes, &centerSettingsMock{})
cntr := New(creds, tc.prefixes)
box, err := cntr.Authenticate(tc.request)
if tc.err {
require.Error(t, err)
if tc.errCode > 0 {
err = frosterr.UnwrapErr(err)
err = frostfsErrors.UnwrapErr(err)
require.Equal(t, errors.GetAPIError(tc.errCode), err)
}
} else {
@ -605,7 +695,3 @@ func getRequestWithMultipartForm(t *testing.T, policy, creds, date, sign, fieldN
return req
}
func getAccessKeyID(addr oid.Address) string {
return strings.ReplaceAll(addr.EncodeToString(), "/", "0")
}

View file

@ -1,14 +1,21 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/private/protocol/rest"
v4a "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/v4"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/smithy-go/encoding/httpbinding"
"github.com/aws/smithy-go/logging"
)
type RequestData struct {
@ -27,25 +34,70 @@ type PresignData struct {
}
// PresignRequest forms pre-signed request to access objects without aws credentials.
func PresignRequest(creds *credentials.Credentials, reqData RequestData, presignData PresignData) (*http.Request, error) {
urlStr := fmt.Sprintf("%s/%s/%s", reqData.Endpoint, rest.EscapePath(reqData.Bucket, false), rest.EscapePath(reqData.Object, false))
func PresignRequest(ctx context.Context, creds aws.Credentials, reqData RequestData, presignData PresignData) (*http.Request, error) {
urlStr := fmt.Sprintf("%s/%s/%s", reqData.Endpoint, httpbinding.EscapePath(reqData.Bucket, false), httpbinding.EscapePath(reqData.Object, false))
req, err := http.NewRequest(strings.ToUpper(reqData.Method), urlStr, nil)
if err != nil {
return nil, fmt.Errorf("failed to create new request: %w", err)
}
req.Header.Set(AmzDate, presignData.SignTime.Format("20060102T150405Z"))
for k, v := range presignData.Headers {
req.Header.Set(k, v)
req.Header.Set(k, v) // maybe we should filter system header (or keep responsibility on caller)
}
req.Header.Set(AmzDate, presignData.SignTime.Format("20060102T150405Z"))
req.Header.Set(AmzExpires, strconv.FormatFloat(presignData.Lifetime.Round(time.Second).Seconds(), 'f', 0, 64))
signer := v4.NewSigner(func(options *v4.SignerOptions) {
options.DisableURIPathEscaping = true
})
signedURI, _, err := signer.PresignHTTP(ctx, creds, req, presignData.Headers[AmzContentSHA256], presignData.Service, presignData.Region, presignData.SignTime)
if err != nil {
return nil, fmt.Errorf("presign: %w", err)
}
signer := v4.NewSigner(creds)
signer.DisableURIPathEscaping = true
if _, err = signer.Presign(req, nil, presignData.Service, presignData.Region, presignData.Lifetime, presignData.SignTime); err != nil {
return nil, fmt.Errorf("presign: %w", err)
if req.URL, err = url.ParseRequestURI(signedURI); err != nil {
return nil, fmt.Errorf("parse signed URI: %w", err)
}
return req, nil
}
// PresignRequestV4a forms pre-signed request to access objects without aws credentials.
func PresignRequestV4a(cred aws.Credentials, reqData RequestData, presignData PresignData) (*http.Request, error) {
urlStr := fmt.Sprintf("%s/%s/%s", reqData.Endpoint, httpbinding.EscapePath(reqData.Bucket, false), httpbinding.EscapePath(reqData.Object, false))
req, err := http.NewRequest(strings.ToUpper(reqData.Method), urlStr, nil)
if err != nil {
return nil, fmt.Errorf("failed to create new request: %w", err)
}
for k, v := range presignData.Headers {
req.Header.Set(k, v) // maybe we should filter system header (or keep responsibility on caller)
}
req.Header.Set(AmzDate, presignData.SignTime.Format("20060102T150405Z"))
req.Header.Set(AmzExpires, strconv.Itoa(int(presignData.Lifetime.Seconds())))
signer := v4a.NewSigner(func(options *v4a.SignerOptions) {
options.DisableURIPathEscaping = true
options.LogSigning = true
options.Logger = logging.NewStandardLogger(os.Stdout)
})
credAdapter := v4a.SymmetricCredentialAdaptor{
SymmetricProvider: credentials.NewStaticCredentialsProvider(cred.AccessKeyID, cred.SecretAccessKey, ""),
}
creds, err := credAdapter.RetrievePrivateKey(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to derive assymetric key from credentials: %w", err)
}
presignedURL, _, err := signer.PresignHTTP(req.Context(), creds, req, presignData.Headers[AmzContentSHA256], presignData.Service, []string{presignData.Region}, presignData.SignTime)
if err != nil {
return nil, fmt.Errorf("presign: %w", err)
}
fmt.Println(presignedURL)
return http.NewRequest(reqData.Method, presignedURL, nil)
}

View file

@ -2,17 +2,23 @@ package auth
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"testing"
"time"
v4a "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/tokens"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/smithy-go/logging"
"github.com/stretchr/testify/require"
)
@ -29,11 +35,11 @@ func newTokensFrostfsMock() *credentialsMock {
}
func (m credentialsMock) addBox(addr oid.Address, box *accessbox.Box) {
m.boxes[getAccessKeyID(addr)] = box
m.boxes[addr.String()] = box
}
func (m credentialsMock) GetBox(_ context.Context, _ cid.ID, accessKeyID string) (*accessbox.Box, []object.Attribute, error) {
box, ok := m.boxes[accessKeyID]
func (m credentialsMock) GetBox(_ context.Context, addr oid.Address) (*accessbox.Box, []object.Attribute, error) {
box, ok := m.boxes[addr.String()]
if !ok {
return nil, nil, &apistatus.ObjectNotFound{}
}
@ -41,22 +47,24 @@ func (m credentialsMock) GetBox(_ context.Context, _ cid.ID, accessKeyID string)
return box, nil, nil
}
func (m credentialsMock) Put(context.Context, tokens.CredentialsParam) (oid.Address, error) {
func (m credentialsMock) Put(context.Context, cid.ID, tokens.CredentialsParam) (oid.Address, error) {
return oid.Address{}, nil
}
func (m credentialsMock) Update(context.Context, tokens.CredentialsParam) (oid.Address, error) {
func (m credentialsMock) Update(context.Context, oid.Address, tokens.CredentialsParam) (oid.Address, error) {
return oid.Address{}, nil
}
func TestCheckSign(t *testing.T) {
ctx := context.Background()
var accessKeyAddr oid.Address
err := accessKeyAddr.DecodeString("8N7CYBY74kxZXoyvA5UNdmovaXqFpwNfvEPsqaN81es2/3tDwq5tR8fByrJcyJwyiuYX7Dae8tyDT7pd8oaL1MBto")
require.NoError(t, err)
accessKeyID := strings.ReplaceAll(accessKeyAddr.String(), "/", "0")
secretKey := "713d0a0b9efc7d22923e17b0402a6a89b4273bc711c8bacb2da1b643d0006aeb"
awsCreds := credentials.NewStaticCredentials(accessKeyID, secretKey, "")
awsCreds := aws.Credentials{AccessKeyID: accessKeyID, SecretAccessKey: secretKey}
reqData := RequestData{
Method: "GET",
@ -69,9 +77,13 @@ func TestCheckSign(t *testing.T) {
Region: "spb",
Lifetime: 10 * time.Minute,
SignTime: time.Now().UTC(),
Headers: map[string]string{
ContentTypeHdr: "text/plain",
AmzContentSHA256: UnsignedPayload,
},
}
req, err := PresignRequest(awsCreds, reqData, presignData)
req, err := PresignRequest(ctx, awsCreds, reqData, presignData)
require.NoError(t, err)
expBox := &accessbox.Box{
@ -85,11 +97,106 @@ func TestCheckSign(t *testing.T) {
c := &Center{
cli: mock,
reg: NewRegexpMatcher(AuthorizationFieldRegexp),
reg: NewRegexpMatcher(authorizationFieldRegexp),
postReg: NewRegexpMatcher(postPolicyCredentialRegexp),
settings: &centerSettingsMock{},
}
box, err := c.Authenticate(req)
require.NoError(t, err)
require.EqualValues(t, expBox, box.AccessBox)
}
func TestCheckSignV4a(t *testing.T) {
var accessKeyAddr oid.Address
err := accessKeyAddr.DecodeString("8N7CYBY74kxZXoyvA5UNdmovaXqFpwNfvEPsqaN81es2/3tDwq5tR8fByrJcyJwyiuYX7Dae8tyDT7pd8oaL1MBto")
require.NoError(t, err)
accessKeyID := strings.ReplaceAll(accessKeyAddr.String(), "/", "0")
secretKey := "713d0a0b9efc7d22923e17b0402a6a89b4273bc711c8bacb2da1b643d0006aeb"
awsCreds := aws.Credentials{AccessKeyID: accessKeyID, SecretAccessKey: secretKey}
reqData := RequestData{
Method: "GET",
Endpoint: "http://localhost:8084",
Bucket: "my-bucket",
Object: "@obj/name",
}
presignData := PresignData{
Service: "s3",
Region: "spb",
Lifetime: 10 * time.Minute,
SignTime: time.Now().UTC(),
Headers: map[string]string{
ContentTypeHdr: "text/plain",
},
}
req, err := PresignRequestV4a(awsCreds, reqData, presignData)
require.NoError(t, err)
req.Header.Set(ContentTypeHdr, "text/plain")
expBox := &accessbox.Box{
Gate: &accessbox.GateData{
SecretKey: secretKey,
},
}
mock := newTokensFrostfsMock()
mock.addBox(accessKeyAddr, expBox)
c := &Center{
cli: mock,
regV4a: NewRegexpMatcher(authorizationFieldV4aRegexp),
postReg: NewRegexpMatcher(postPolicyCredentialRegexp),
}
box, err := c.Authenticate(req)
require.NoError(t, err)
require.EqualValues(t, expBox, box.AccessBox)
}
func TestPresignRequestV4a(t *testing.T) {
var accessKeyAddr oid.Address
err := accessKeyAddr.DecodeString("8N7CYBY74kxZXoyvA5UNdmovaXqFpwNfvEPsqaN81es2/3tDwq5tR8fByrJcyJwyiuYX7Dae8tyDT7pd8oaL1MBto")
require.NoError(t, err)
accessKeyID := strings.ReplaceAll(accessKeyAddr.String(), "/", "0")
secretKey := "713d0a0b9efc7d22923e17b0402a6a89b4273bc711c8bacb2da1b643d0006aeb"
signer := v4a.NewSigner(func(options *v4a.SignerOptions) {
options.DisableURIPathEscaping = true
options.LogSigning = true
options.Logger = logging.NewStandardLogger(os.Stdout)
})
credAdapter := v4a.SymmetricCredentialAdaptor{
SymmetricProvider: credentialsv2.NewStaticCredentialsProvider(accessKeyID, secretKey, ""),
}
creds, err := credAdapter.RetrievePrivateKey(context.TODO())
require.NoError(t, err)
signingTime := time.Now()
service := "s3"
regionSet := []string{"spb"}
req, err := http.NewRequest("GET", "http://localhost:8084/bucket/object", nil)
require.NoError(t, err)
req.Header.Set(AmzExpires, "600")
presignedURL, hdr, err := signer.PresignHTTP(req.Context(), creds, req, "", service, regionSet, signingTime)
require.NoError(t, err)
fmt.Println(presignedURL)
fmt.Println(hdr)
signature := req.URL.Query().Get(AmzSignature)
r, err := http.NewRequest("GET", presignedURL, nil)
require.NoError(t, err)
query := r.URL.Query()
query.Del(AmzSignature)
r.URL.RawQuery = query.Encode()
err = signer.VerifyPresigned(creds, r, "", service, regionSet, signingTime, signature)
require.NoError(t, err)
}

View file

@ -1,87 +0,0 @@
package v4
import (
"strings"
)
// validator houses a set of rule needed for validation of a
// string value.
type rules []rule
// rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that rule.
type rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules.
func (r rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// mapRule generic rule for maps.
type mapRule map[string]struct{}
// IsValid for the map rule satisfies whether it exists in the map.
func (m mapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// whitelist is a generic rule for whitelisting.
type whitelist struct {
rule
}
// IsValid for whitelist checks if the value is within the whitelist.
func (w whitelist) IsValid(value string) bool {
return w.rule.IsValid(value)
}
// blacklist is a generic rule for blacklisting.
type blacklist struct {
rule
}
// IsValid for whitelist checks if the value is within the whitelist.
func (b blacklist) IsValid(value string) bool {
return !b.rule.IsValid(value)
}
type patterns []string
// IsValid for patterns checks each pattern and returns if a match has
// been found.
func (p patterns) IsValid(value string) bool {
for _, pattern := range p {
if HasPrefixFold(value, pattern) {
return true
}
}
return false
}
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}
// inclusiveRules rules allow for rules to depend on one another.
type inclusiveRules []rule
// IsValid will return true if all rules are true.
func (r inclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}

View file

@ -1,7 +0,0 @@
package v4
// WithUnsignedPayload will enable and set the UnsignedPayload field to
// true of the signer.
func WithUnsignedPayload(v4 *Signer) {
v4.UnsignedPayload = true
}

View file

@ -1,14 +0,0 @@
//go:build go1.7
// +build go1.7
package v4
import (
"net/http"
"github.com/aws/aws-sdk-go/aws"
)
func requestContext(r *http.Request) aws.Context {
return r.Context()
}

View file

@ -1,63 +0,0 @@
package v4
import (
"encoding/hex"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
type credentialValueProvider interface {
Get() (credentials.Value, error)
}
// StreamSigner implements signing of event stream encoded payloads.
type StreamSigner struct {
region string
service string
credentials credentialValueProvider
prevSig []byte
}
// NewStreamSigner creates a SigV4 signer used to sign Event Stream encoded messages.
func NewStreamSigner(region, service string, seedSignature []byte, credentials *credentials.Credentials) *StreamSigner {
return &StreamSigner{
region: region,
service: service,
credentials: credentials,
prevSig: seedSignature,
}
}
// GetSignature takes an event stream encoded headers and payload and returns a signature.
func (s *StreamSigner) GetSignature(headers, payload []byte, date time.Time) ([]byte, error) {
credValue, err := s.credentials.Get()
if err != nil {
return nil, err
}
sigKey := deriveSigningKey(s.region, s.service, credValue.SecretAccessKey, date)
keyPath := buildSigningScope(s.region, s.service, date)
stringToSign := buildEventStreamStringToSign(headers, payload, s.prevSig, keyPath, date)
signature := hmacSHA256(sigKey, []byte(stringToSign))
s.prevSig = signature
return signature, nil
}
func buildEventStreamStringToSign(headers, payload, prevSig []byte, scope string, date time.Time) string {
return strings.Join([]string{
"AWS4-HMAC-SHA256-PAYLOAD",
formatTime(date),
scope,
hex.EncodeToString(prevSig),
hex.EncodeToString(hashSHA256(headers)),
hex.EncodeToString(hashSHA256(payload)),
}, "\n")
}

View file

@ -1,25 +0,0 @@
//go:build go1.5
// +build go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -1,858 +0,0 @@
// Package v4 implements signing for AWS V4 signer
//
// Provides request signing for request that need to be signed with
// AWS V4 Signatures.
//
// # Standalone Signer
//
// Generally using the signer outside of the SDK should not require any additional
// logic when using Go v1.5 or higher. The signer does this by taking advantage
// of the URL.EscapedPath method. If your request URI requires additional escaping
// you many need to use the URL.Opaque to define what the raw URI should be sent
// to the service as.
//
// The signer will first check the URL.Opaque field, and use its value if set.
// The signer does require the URL.Opaque field to be set in the form of:
//
// "//<hostname>/<path>"
//
// // e.g.
// "//example.com/some/path"
//
// The leading "//" and hostname are required or the URL.Opaque escaping will
// not work correctly.
//
// If URL.Opaque is not set the signer will fallback to the URL.EscapedPath()
// method and using the returned value. If you're using Go v1.4 you must set
// URL.Opaque if the URI path needs escaping. If URL.Opaque is not set with
// Go v1.5 the signer will fallback to URL.Path.
//
// AWS v4 signature validation requires that the canonical string's URI path
// element must be the URI escaped form of the HTTP request's path.
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
//
// The Go HTTP client will perform escaping automatically on the request. Some
// of these escaping may cause signature validation errors because the HTTP
// request differs from the URI path or query that the signature was generated.
// https://golang.org/pkg/net/url/#URL.EscapedPath
//
// Because of this, it is recommended that when using the signer outside of the
// SDK that explicitly escaping the request prior to being signed is preferable,
// and will help prevent signature validation errors. This can be done by setting
// the URL.Opaque or URL.RawPath. The SDK will use URL.Opaque first and then
// call URL.EscapedPath() if Opaque is not set.
//
// If signing a request intended for HTTP2 server, and you're using Go 1.6.2
// through 1.7.4 you should use the URL.RawPath as the pre-escaped form of the
// request URL. https://github.com/golang/go/issues/16847 points to a bug in
// Go pre 1.8 that fails to make HTTP2 requests using absolute URL in the HTTP
// message. URL.Opaque generally will force Go to make requests with absolute URL.
// URL.RawPath does not do this, but RawPath must be a valid escaping of Path
// or url.EscapedPath will ignore the RawPath escaping.
//
// Test `TestStandaloneSign` provides a complete example of using the signer
// outside of the SDK and pre-escaping the URI path.
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
signatureQueryKey = "X-Amz-Signature"
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string.
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
var ignoredPresignHeaders = rules{
blacklist{
mapRule{
authorizationHeader: struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// drop User-Agent header to be compatible with aws sdk java v1.
var ignoredHeaders = rules{
blacklist{
mapRule{
authorizationHeader: struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// requiredSignedHeaders is a whitelist for build canonical headers.
var requiredSignedHeaders = rules{
whitelist{
mapRule{
"Cache-Control": struct{}{},
"Content-Disposition": struct{}{},
"Content-Encoding": struct{}{},
"Content-Language": struct{}{},
"Content-Md5": struct{}{},
"Content-Type": struct{}{},
"Expires": struct{}{},
"If-Match": struct{}{},
"If-Modified-Since": struct{}{},
"If-None-Match": struct{}{},
"If-Unmodified-Since": struct{}{},
"Range": struct{}{},
"X-Amz-Acl": struct{}{},
"X-Amz-Copy-Source": struct{}{},
"X-Amz-Copy-Source-If-Match": struct{}{},
"X-Amz-Copy-Source-If-Modified-Since": struct{}{},
"X-Amz-Copy-Source-If-None-Match": struct{}{},
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{},
"X-Amz-Copy-Source-Range": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Grant-Full-control": struct{}{},
"X-Amz-Grant-Read": struct{}{},
"X-Amz-Grant-Read-Acp": struct{}{},
"X-Amz-Grant-Write": struct{}{},
"X-Amz-Grant-Write-Acp": struct{}{},
"X-Amz-Metadata-Directive": struct{}{},
"X-Amz-Mfa": struct{}{},
"X-Amz-Request-Payer": struct{}{},
"X-Amz-Server-Side-Encryption": struct{}{},
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{},
"X-Amz-Tagging": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{},
},
},
patterns{"X-Amz-Meta-"},
}
// allowedHoisting is a whitelist for build query headers. The boolean value
// represents whether or not it is a pattern.
var allowedQueryHoisting = inclusiveRules{
blacklist{requiredSignedHeaders},
patterns{"X-Amz-"},
}
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
type Signer struct {
// The authentication credentials the request will be signed against.
// This value must be set to sign requests.
Credentials *credentials.Credentials
// Sets the log level the signer should use when reporting information to
// the logger. If the logger is nil nothing will be logged. See
// aws.LogLevelType for more information on available logging levels
//
// By default nothing will be logged.
Debug aws.LogLevelType
// The logger loging information will be written to. If there the logger
// is nil, nothing will be logged.
Logger aws.Logger
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
// request header to the request's query string. This is most commonly used
// with pre-signed requests preventing headers from being added to the
// request's query string.
DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// siganture's canonical string's path. For services that do not need additional
// escaping then use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// Disables the automatical setting of the HTTP request's Body field with the
// io.ReadSeeker passed in to the signer. This is useful if you're using a
// custom wrapper around the body for the io.ReadSeeker and want to preserve
// the Body value on the Request.Body.
//
// This does run the risk of signing a request with a body that will not be
// sent in the request. Need to ensure that the underlying data of the Body
// values are the same.
DisableRequestBodyOverwrite bool
// currentTimeFn returns the time value which represents the current time.
// This value should only be used for testing. If it is nil the default
// time.Now will be used.
currentTimeFn func() time.Time
// UnsignedPayload will prevent signing of the payload. This will only
// work for services that have support for this.
UnsignedPayload bool
}
// NewSigner returns a Signer pointer configured with the credentials and optional
// option values provided. If not options are provided the Signer will use its
// default configuration.
func NewSigner(credentials *credentials.Credentials, options ...func(*Signer)) *Signer {
v4 := &Signer{
Credentials: credentials,
}
for _, option := range options {
option(v4)
}
return v4
}
type signingCtx struct {
ServiceName string
Region string
Request *http.Request
Body io.ReadSeeker
Query url.Values
Time time.Time
ExpireTime time.Duration
SignedHeaderVals http.Header
DisableURIPathEscaping bool
credValues credentials.Value
isPresign bool
unsignedPayload bool
bodyDigest string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
}
// Sign signs AWS v4 requests with the provided body, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. Generally for signed requests this value
// is not needed as the full request context will be captured by the http.Request
// value. It is included for reference though.
//
// Sign will set the request's Body to be the `body` parameter passed in. If
// the body is not already an io.ReadCloser, it will be wrapped within one. If
// a `nil` body parameter passed to Sign, the request's Body field will be
// also set to nil. Its important to note that this functionality will not
// change the request's ContentLength of the request.
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, 0, false, signTime)
}
// Presign signs AWS v4 requests with the provided body, service name, region
// the request is made to, and time the request is signed at. The signTime
// allows you to specify that a request is signed for the future, and cannot
// be used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. For presigned requests these headers
// and their values must be included on the HTTP request when it is made. This
// is helpful to know what header values need to be shared with the party the
// presigned request will be distributed to.
//
// Presign differs from Sign in that it will sign the request using query string
// instead of header values. This allows you to share the Presigned Request's
// URL with third parties, or distribute it throughout your system with minimal
// dependencies.
//
// Presign also takes an exp value which is the duration the
// signed request will be valid after the signing time. This is allows you to
// set when the request will expire.
//
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
//
// Presigning a S3 request will not compute the body's SHA256 hash by default.
// This is done due to the general use case for S3 presigned URLs is to share
// PUT/GET capabilities. If you would like to include the body's SHA256 in the
// presigned request's signature you can set the "X-Amz-Content-Sha256"
// HTTP header and that will be included in the request's signature.
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, exp, true, signTime)
}
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, isPresign bool, signTime time.Time) (http.Header, error) {
currentTimeFn := v4.currentTimeFn
if currentTimeFn == nil {
currentTimeFn = time.Now
}
ctx := &signingCtx{
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ExpireTime: exp,
isPresign: isPresign,
ServiceName: service,
Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
unsignedPayload: v4.UnsignedPayload,
}
for key := range ctx.Query {
sort.Strings(ctx.Query[key])
}
if ctx.isRequestSigned() {
ctx.Time = currentTimeFn()
ctx.handlePresignRemoval()
}
var err error
ctx.credValues, err = v4.Credentials.GetWithContext(requestContext(r))
if err != nil {
return http.Header{}, err
}
ctx.sanitizeHostForHeader()
ctx.assignAmzQueryValues()
if err := ctx.build(v4.DisableHeaderHoisting); err != nil {
return nil, err
}
// If the request is not presigned the body should be attached to it. This
// prevents the confusion of wanting to send a signed request without
// the body the request was signed for attached.
if !(v4.DisableRequestBodyOverwrite || ctx.isPresign) {
var reader io.ReadCloser
if body != nil {
var ok bool
if reader, ok = body.(io.ReadCloser); !ok {
reader = io.NopCloser(body)
}
}
r.Body = reader
}
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo(ctx)
}
return ctx.SignedHeaderVals, nil
}
func (ctx *signingCtx) sanitizeHostForHeader() {
request.SanitizeHostForHeader(ctx.Request)
}
func (ctx *signingCtx) handlePresignRemoval() {
if !ctx.isPresign {
return
}
// The credentials have expired for this request. The current signing
// is invalid, and needs to be request because the request will fail.
ctx.removePresign()
// Update the request's query string to ensure the values stays in
// sync in the case retrieving the new credentials fails.
ctx.Request.URL.RawQuery = ctx.Query.Encode()
}
func (ctx *signingCtx) assignAmzQueryValues() {
if ctx.isPresign {
ctx.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if ctx.credValues.SessionToken != "" {
ctx.Query.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
} else {
ctx.Query.Del("X-Amz-Security-Token")
}
return
}
if ctx.credValues.SessionToken != "" {
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
}
}
// SignRequestHandler is a named request handler the SDK will use to sign
// service client request with using the V4 signature.
var SignRequestHandler = request.NamedHandler{
Name: "v4.SignRequestHandler", Fn: SignSDKRequest,
}
// SignSDKRequest signs an AWS request with the V4 signature. This
// request handler should only be used with the SDK's built in service client's
// API operation requests.
//
// This function should not be used on its on its own, but in conjunction with
// an AWS service client's API operation call. To sign a standalone request
// not created by a service client's API operation method use the "Sign" or
// "Presign" functions of the "Signer" type.
//
// If the credentials of the request's config are set to
// credentials.AnonymousCredentials the request will not be signed.
func SignSDKRequest(req *request.Request) {
SignSDKRequestWithCurrentTime(req, time.Now)
}
// BuildNamedHandler will build a generic handler for signing.
func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler {
return request.NamedHandler{
Name: name,
Fn: func(req *request.Request) {
SignSDKRequestWithCurrentTime(req, time.Now, opts...)
},
}
}
// SignSDKRequestWithCurrentTime will sign the SDK's request using the time
// function passed in. Behaves the same as SignSDKRequest with the exception
// the request is signed with the value returned by the current time function.
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials {
return
}
region := req.ClientInfo.SigningRegion
if region == "" {
region = aws.StringValue(req.Config.Region)
}
name := req.ClientInfo.SigningName
if name == "" {
name = req.ClientInfo.ServiceName
}
v4 := NewSigner(req.Config.Credentials, func(v4 *Signer) {
v4.Debug = req.Config.LogLevel.Value()
v4.Logger = req.Config.Logger
v4.DisableHeaderHoisting = req.NotHoist
v4.currentTimeFn = curTimeFn
if name == "s3" {
// S3 service should not have any escaping applied
v4.DisableURIPathEscaping = true
}
// Prevents setting the HTTPRequest's Body. Since the Body could be
// wrapped in a custom io.Closer that we do not want to be stompped
// on top of by the signer.
v4.DisableRequestBodyOverwrite = true
})
for _, opt := range opts {
opt(v4)
}
curTime := curTimeFn()
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
)
if err != nil {
req.Error = err
req.SignedHeaderVals = nil
return
}
req.SignedHeaderVals = signedHeaders
req.LastSignedAt = curTime
}
const logSignInfoMsg = `DEBUG: Request Signature:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *Signer) logSigningInfo(ctx *signingCtx) {
signedURLMsg := ""
if ctx.isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, ctx.Request.URL.String())
}
msg := fmt.Sprintf(logSignInfoMsg, ctx.canonicalString, ctx.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
ctx.buildTime() // no depends
ctx.buildCredentialString() // no depends
if err := ctx.buildBodyDigest(); err != nil {
return err
}
unsignedHeaders := ctx.Request.Header
if ctx.isPresign {
if !disableHeaderHoisting {
var urlValues url.Values
urlValues, unsignedHeaders = buildQuery(allowedQueryHoisting, unsignedHeaders) // no depends
for k := range urlValues {
ctx.Query[k] = urlValues[k]
}
}
}
if ctx.isPresign {
ctx.buildCanonicalHeaders(ignoredPresignHeaders, unsignedHeaders)
} else {
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
}
ctx.buildCanonicalString() // depends on canon headers / signed headers
ctx.buildStringToSign() // depends on canon string
ctx.buildSignature() // depends on string to sign
if ctx.isPresign {
ctx.Request.URL.RawQuery += "&" + signatureQueryKey + "=" + ctx.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders,
authHeaderSignatureElem + ctx.signature,
}
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
}
return nil
}
// GetSignedRequestSignature attempts to extract the signature of the request.
// Returning an error if the request is unsigned, or unable to extract the
// signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
if ctx.isPresign {
duration := int64(ctx.ExpireTime / time.Second)
ctx.Query.Set("X-Amz-Date", formatTime(ctx.Time))
ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}
}
func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
if ctx.isPresign {
ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString)
}
}
func buildQuery(r rule, header http.Header) (url.Values, http.Header) {
query := url.Values{}
unsignedHeaders := http.Header{}
for k, h := range header {
if r.IsValid(k) {
query[k] = h
} else {
unsignedHeaders[k] = h
}
}
return query, unsignedHeaders
}
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
var headers []string
headers = append(headers, "host")
for k, v := range header {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
ctx.SignedHeaderVals = make(http.Header)
}
lowerCaseKey := strings.ToLower(k)
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
// include additional values
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
ctx.SignedHeaderVals[lowerCaseKey] = v
}
sort.Strings(headers)
ctx.signedHeaders = strings.Join(headers, ";")
if ctx.isPresign {
ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
if ctx.Request.Host != "" {
headerValues[i] = "host:" + ctx.Request.Host
} else {
headerValues[i] = "host:" + ctx.Request.URL.Host
}
} else {
headerValues[i] = k + ":" +
strings.Join(ctx.SignedHeaderVals[k], ",")
}
}
stripExcessSpaces(headerValues)
ctx.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
uri := getURIPath(ctx.Request.URL)
if !ctx.DisableURIPathEscaping {
uri = rest.EscapePath(uri, false)
}
ctx.canonicalString = strings.Join([]string{
ctx.Request.Method,
uri,
ctx.Request.URL.RawQuery,
ctx.canonicalHeaders + "\n",
ctx.signedHeaders,
ctx.bodyDigest,
}, "\n")
}
func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{
authHeaderPrefix,
formatTime(ctx.Time),
ctx.credentialString,
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n")
}
func (ctx *signingCtx) buildSignature() {
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature)
}
func (ctx *signingCtx) buildBodyDigest() error {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
includeSHA256Header := ctx.unsignedPayload ||
ctx.ServiceName == "s3" ||
ctx.ServiceName == "glacier"
s3Presign := ctx.isPresign && ctx.ServiceName == "s3"
if ctx.unsignedPayload || s3Presign {
hash = "UNSIGNED-PAYLOAD"
includeSHA256Header = !s3Presign
} else if ctx.Body == nil {
hash = emptyStringSHA256
} else {
if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
}
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
if includeSHA256Header {
ctx.Request.Header.Set("X-Amz-Content-Sha256", hash)
}
}
ctx.bodyDigest = hash
return nil
}
// isRequestSigned returns if the request is currently signed or presigned.
func (ctx *signingCtx) isRequestSigned() bool {
if ctx.isPresign && ctx.Query.Get("X-Amz-Signature") != "" {
return true
}
if ctx.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// unsign removes signing flags for both signed and presigned requests.
func (ctx *signingCtx) removePresign() {
ctx.Query.Del("X-Amz-Algorithm")
ctx.Query.Del("X-Amz-Signature")
ctx.Query.Del("X-Amz-Security-Token")
ctx.Query.Del("X-Amz-Date")
ctx.Query.Del("X-Amz-Expires")
ctx.Query.Del("X-Amz-Credential")
ctx.Query.Del("X-Amz-SignedHeaders")
}
func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func hashSHA256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, io.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := aws.SeekerLen(reader)
if err != nil {
_, _ = io.Copy(hash, reader)
} else {
_, _ = io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
//
//nolint:revive
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
vals[i] = str
continue
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
vals[i] = string(buf[:m])
}
}
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
hmacDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
hmacRegion := hmacSHA256(hmacDate, []byte(region))
hmacService := hmacSHA256(hmacRegion, []byte(service))
signingKey := hmacSHA256(hmacService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View file

@ -0,0 +1,140 @@
package v4a
import (
"context"
"crypto/ecdsa"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
)
// Credentials is Context, ECDSA, and Optional Session Token that can be used
// to sign requests using SigV4a
type Credentials struct {
Context string
PrivateKey *ecdsa.PrivateKey
SessionToken string
// Time the credentials will expire.
CanExpire bool
Expires time.Time
}
// Expired returns if the credentials have expired.
func (v Credentials) Expired() bool {
if v.CanExpire {
return !v.Expires.After(time.Now())
}
return false
}
// HasKeys returns if the credentials keys are set.
func (v Credentials) HasKeys() bool {
return len(v.Context) > 0 && v.PrivateKey != nil
}
// SymmetricCredentialAdaptor wraps a SigV4 AccessKey/SecretKey provider and adapts the credentials
// to a ECDSA PrivateKey for signing with SiV4a
type SymmetricCredentialAdaptor struct {
SymmetricProvider aws.CredentialsProvider
asymmetric atomic.Value
m sync.Mutex
}
// Retrieve retrieves symmetric credentials from the underlying provider.
func (s *SymmetricCredentialAdaptor) Retrieve(ctx context.Context) (aws.Credentials, error) {
symCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return aws.Credentials{}, err
}
if asymCreds := s.getCreds(); asymCreds == nil {
return symCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
asymCreds := s.getCreds()
if asymCreds == nil {
return symCreds, nil
}
// if the context does not match the access key id clear it
if asymCreds.Context != symCreds.AccessKeyID {
s.asymmetric.Store((*Credentials)(nil))
}
return symCreds, nil
}
// RetrievePrivateKey returns credentials suitable for SigV4a signing
func (s *SymmetricCredentialAdaptor) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
symmetricCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return Credentials{}, fmt.Errorf("failed to retrieve symmetric credentials: %v", err)
}
privateKey, err := deriveKeyFromAccessKeyPair(symmetricCreds.AccessKeyID, symmetricCreds.SecretAccessKey)
if err != nil {
return Credentials{}, fmt.Errorf("failed to derive assymetric key from credentials")
}
creds := Credentials{
Context: symmetricCreds.AccessKeyID,
PrivateKey: privateKey,
SessionToken: symmetricCreds.SessionToken,
CanExpire: symmetricCreds.CanExpire,
Expires: symmetricCreds.Expires,
}
s.asymmetric.Store(&creds)
return creds, nil
}
func (s *SymmetricCredentialAdaptor) getCreds() *Credentials {
v := s.asymmetric.Load()
if v == nil {
return nil
}
c := v.(*Credentials)
if c != nil && c.HasKeys() && !c.Expired() {
return c
}
return nil
}
func (s *SymmetricCredentialAdaptor) retrieveFromSymmetricProvider(ctx context.Context) (aws.Credentials, error) {
credentials, err := s.SymmetricProvider.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, err
}
return credentials, nil
}
// CredentialsProvider is the interface for a provider to retrieve credentials
// to sign requests with.
type CredentialsProvider interface {
RetrievePrivateKey(context.Context) (Credentials, error)
}

View file

@ -0,0 +1,77 @@
package v4a
import (
"context"
"fmt"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
)
type rotatingCredsProvider struct {
count int
fail chan struct{}
}
func (r *rotatingCredsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
select {
case <-r.fail:
return aws.Credentials{}, fmt.Errorf("rotatingCredsProvider error")
default:
}
credentials := aws.Credentials{
AccessKeyID: fmt.Sprintf("ACCESS_KEY_ID_%d", r.count),
SecretAccessKey: fmt.Sprintf("SECRET_ACCESS_KEY_%d", r.count),
SessionToken: fmt.Sprintf("SESSION_TOKEN_%d", r.count),
}
return credentials, nil
}
func TestSymmetricCredentialAdaptor(t *testing.T) {
provider := &rotatingCredsProvider{
count: 0,
fail: make(chan struct{}),
}
adaptor := &SymmetricCredentialAdaptor{SymmetricProvider: provider}
if symCreds, err := adaptor.Retrieve(context.Background()); err != nil {
t.Fatalf("expect no error, got %v", err)
} else if !symCreds.HasKeys() {
t.Fatalf("expect symmetric credentials to have keys")
}
if load := adaptor.asymmetric.Load(); load != nil {
t.Errorf("expect asymmetric credentials to be nil")
}
if asymCreds, err := adaptor.RetrievePrivateKey(context.Background()); err != nil {
t.Fatalf("expect no error, got %v", err)
} else if !asymCreds.HasKeys() {
t.Fatalf("expect asymmetric credentials to have keys")
}
if _, err := adaptor.Retrieve(context.Background()); err != nil {
t.Fatalf("expect no error, got %v", err)
}
if load := adaptor.asymmetric.Load(); load.(*Credentials) == nil {
t.Errorf("expect asymmetric credentials to be not nil")
}
provider.count++
if _, err := adaptor.Retrieve(context.Background()); err != nil {
t.Fatalf("expect no error, got %v", err)
}
if load := adaptor.asymmetric.Load(); load.(*Credentials) != nil {
t.Errorf("expect asymmetric credentials to be nil")
}
close(provider.fail) // All requests to the original provider will now fail from this point-on.
_, err := adaptor.Retrieve(context.Background())
if err == nil {
t.Error("expect error, got nil")
}
}

View file

@ -0,0 +1,17 @@
package v4a
import "fmt"
// SigningError indicates an error condition occurred while performing SigV4a signing
type SigningError struct {
Err error
}
func (e *SigningError) Error() string {
return fmt.Sprintf("failed to sign request: %v", e.Err)
}
// Unwrap returns the underlying error cause
func (e *SigningError) Unwrap() error {
return e.Err
}

View file

@ -0,0 +1,30 @@
package crypto
import "fmt"
// ConstantTimeByteCompare is a constant-time byte comparison of x and y. This function performs an absolute comparison
// if the two byte slices assuming they represent a big-endian number.
//
// error if len(x) != len(y)
// -1 if x < y
// 0 if x == y
// +1 if x > y
func ConstantTimeByteCompare(x, y []byte) (int, error) {
if len(x) != len(y) {
return 0, fmt.Errorf("slice lengths do not match")
}
xLarger, yLarger := 0, 0
for i := 0; i < len(x); i++ {
xByte, yByte := int(x[i]), int(y[i])
x := ((yByte - xByte) >> 8) & 1
y := ((xByte - yByte) >> 8) & 1
xLarger |= x &^ yLarger
yLarger |= y &^ xLarger
}
return xLarger - yLarger, nil
}

View file

@ -0,0 +1,60 @@
package crypto
import (
"bytes"
"math/big"
"testing"
)
func TestConstantTimeByteCompare(t *testing.T) {
cases := []struct {
x, y []byte
r int
expectErr bool
}{
{x: []byte{}, y: []byte{}, r: 0},
{x: []byte{40}, y: []byte{30}, r: 1},
{x: []byte{30}, y: []byte{40}, r: -1},
{x: []byte{60, 40, 30, 10, 20}, y: []byte{50, 30, 20, 0, 10}, r: 1},
{x: []byte{50, 30, 20, 0, 10}, y: []byte{60, 40, 30, 10, 20}, r: -1},
{x: nil, y: []byte{}, r: 0},
{x: []byte{}, y: nil, r: 0},
{x: []byte{}, y: []byte{10}, expectErr: true},
{x: []byte{10}, y: []byte{}, expectErr: true},
{x: []byte{10, 20}, y: []byte{10}, expectErr: true},
}
for _, tt := range cases {
compare, err := ConstantTimeByteCompare(tt.x, tt.y)
if (err != nil) != tt.expectErr {
t.Fatalf("expectErr=%v, got %v", tt.expectErr, err)
}
if e, a := tt.r, compare; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
}
func BenchmarkConstantTimeCompare(b *testing.B) {
x, y := big.NewInt(1023), big.NewInt(1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ConstantTimeByteCompare(x.Bytes(), y.Bytes())
}
}
func BenchmarkCompare(b *testing.B) {
x, y := big.NewInt(1023).Bytes(), big.NewInt(1024).Bytes()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bytes.Compare(x, y)
}
}
func mustBigInt(s string) *big.Int {
b, ok := (&big.Int{}).SetString(s, 16)
if !ok {
panic("can't parse as big.Int")
}
return b
}

View file

@ -0,0 +1,113 @@
package crypto
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"encoding/asn1"
"encoding/binary"
"fmt"
"hash"
"math"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// ECDSAKey takes the given elliptic curve, and private key (d) byte slice
// and returns the private ECDSA key.
func ECDSAKey(curve elliptic.Curve, d []byte) *ecdsa.PrivateKey {
return ECDSAKeyFromPoint(curve, (&big.Int{}).SetBytes(d))
}
// ECDSAKeyFromPoint takes the given elliptic curve and point and returns the
// private and public keypair
func ECDSAKeyFromPoint(curve elliptic.Curve, d *big.Int) *ecdsa.PrivateKey {
pX, pY := curve.ScalarBaseMult(d.Bytes())
privKey := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: pX,
Y: pY,
},
D: d,
}
return privKey
}
// ECDSAPublicKey takes the provide curve and (x, y) coordinates and returns
// *ecdsa.PublicKey. Returns an error if the given points are not on the curve.
func ECDSAPublicKey(curve elliptic.Curve, x, y []byte) (*ecdsa.PublicKey, error) {
xPoint := (&big.Int{}).SetBytes(x)
yPoint := (&big.Int{}).SetBytes(y)
if !curve.IsOnCurve(xPoint, yPoint) {
return nil, fmt.Errorf("point(%v, %v) is not on the given curve", xPoint.String(), yPoint.String())
}
return &ecdsa.PublicKey{
Curve: curve,
X: xPoint,
Y: yPoint,
}, nil
}
// VerifySignature takes the provided public key, hash, and asn1 encoded signature and returns
// whether the given signature is valid.
func VerifySignature(key *ecdsa.PublicKey, hash []byte, signature []byte) (bool, error) {
var ecdsaSignature ecdsaSignature
_, err := asn1.Unmarshal(signature, &ecdsaSignature)
if err != nil {
return false, err
}
return ecdsa.Verify(key, hash, ecdsaSignature.R, ecdsaSignature.S), nil
}
// HMACKeyDerivation provides an implementation of a NIST-800-108 of a KDF (Key Derivation Function) in Counter Mode.
// For the purposes of this implantation HMAC is used as the PRF (Pseudorandom function), where the value of
// `r` is defined as a 4 byte counter.
func HMACKeyDerivation(hash func() hash.Hash, bitLen int, key []byte, label, context []byte) ([]byte, error) {
// verify that we won't overflow the counter
n := int64(math.Ceil((float64(bitLen) / 8) / float64(hash().Size())))
if n > 0x7FFFFFFF {
return nil, fmt.Errorf("unable to derive key of size %d using 32-bit counter", bitLen)
}
// verify the requested bit length is not larger then the length encoding size
if int64(bitLen) > 0x7FFFFFFF {
return nil, fmt.Errorf("bitLen is greater than 32-bits")
}
fixedInput := bytes.NewBuffer(nil)
fixedInput.Write(label)
fixedInput.WriteByte(0x00)
fixedInput.Write(context)
if err := binary.Write(fixedInput, binary.BigEndian, int32(bitLen)); err != nil {
return nil, fmt.Errorf("failed to write bit length to fixed input string: %v", err)
}
var output []byte
h := hmac.New(hash, key)
for i := int64(1); i <= n; i++ {
h.Reset()
if err := binary.Write(h, binary.BigEndian, int32(i)); err != nil {
return nil, err
}
_, err := h.Write(fixedInput.Bytes())
if err != nil {
return nil, err
}
output = append(output, h.Sum(nil)...)
}
return output[:bitLen/8], nil
}

View file

@ -0,0 +1,277 @@
package crypto
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"io"
"testing"
)
func TestECDSAPublicKeyDerivation_P256(t *testing.T) {
d := []byte{
0xc9, 0x80, 0x68, 0x98, 0xa0, 0x33, 0x49, 0x16, 0xc8, 0x60, 0x74, 0x88, 0x80, 0xa5, 0x41, 0xf0,
0x93, 0xb5, 0x79, 0xa9, 0xb1, 0xf3, 0x29, 0x34, 0xd8, 0x6c, 0x36, 0x3c, 0x39, 0x80, 0x03, 0x57,
}
x := []byte{
0xd0, 0x72, 0x0d, 0xc6, 0x91, 0xaa, 0x80, 0x09, 0x6b, 0xa3, 0x2f, 0xed, 0x1c, 0xb9, 0x7c, 0x2b,
0x62, 0x06, 0x90, 0xd0, 0x6d, 0xe0, 0x31, 0x7b, 0x86, 0x18, 0xd5, 0xce, 0x65, 0xeb, 0x72, 0x8f,
}
y := []byte{
0x96, 0x81, 0xb5, 0x17, 0xb1, 0xcd, 0xa1, 0x7d, 0x0d, 0x83, 0xd3, 0x35, 0xd9, 0xc4, 0xa8, 0xa9,
0xa9, 0xb0, 0xb1, 0xb3, 0xc7, 0x10, 0x6d, 0x8f, 0x3c, 0x72, 0xbc, 0x50, 0x93, 0xdc, 0x27, 0x5f,
}
testKeyDerivation(t, elliptic.P256(), d, x, y)
}
func TestECDSAPublicKeyDerivation_P384(t *testing.T) {
d := []byte{
0x53, 0x94, 0xf7, 0x97, 0x3e, 0xa8, 0x68, 0xc5, 0x2b, 0xf3, 0xff, 0x8d, 0x8c, 0xee, 0xb4, 0xdb,
0x90, 0xa6, 0x83, 0x65, 0x3b, 0x12, 0x48, 0x5d, 0x5f, 0x62, 0x7c, 0x3c, 0xe5, 0xab, 0xd8, 0x97,
0x8f, 0xc9, 0x67, 0x3d, 0x14, 0xa7, 0x1d, 0x92, 0x57, 0x47, 0x93, 0x16, 0x62, 0x49, 0x3c, 0x37,
}
x := []byte{
0xfd, 0x3c, 0x84, 0xe5, 0x68, 0x9b, 0xed, 0x27, 0x0e, 0x60, 0x1b, 0x3d, 0x80, 0xf9, 0x0d, 0x67,
0xa9, 0xae, 0x45, 0x1c, 0xce, 0x89, 0x0f, 0x53, 0xe5, 0x83, 0x22, 0x9a, 0xd0, 0xe2, 0xee, 0x64,
0x56, 0x11, 0xfa, 0x99, 0x36, 0xdf, 0xa4, 0x53, 0x06, 0xec, 0x18, 0x06, 0x67, 0x74, 0xaa, 0x24,
}
y := []byte{
0xb8, 0x3c, 0xa4, 0x12, 0x6c, 0xfc, 0x4c, 0x4d, 0x1d, 0x18, 0xa4, 0xb6, 0xc2, 0x1c, 0x7f, 0x69,
0x9d, 0x51, 0x23, 0xdd, 0x9c, 0x24, 0xf6, 0x6f, 0x83, 0x38, 0x46, 0xee, 0xb5, 0x82, 0x96, 0x19,
0x6b, 0x42, 0xec, 0x06, 0x42, 0x5d, 0xb5, 0xb7, 0x0a, 0x4b, 0x81, 0xb7, 0xfc, 0xf7, 0x05, 0xa0,
}
testKeyDerivation(t, elliptic.P384(), d, x, y)
}
func TestECDSAKnownSigningValue_P256(t *testing.T) {
d := []byte{
0x51, 0x9b, 0x42, 0x3d, 0x71, 0x5f, 0x8b, 0x58, 0x1f, 0x4f, 0xa8, 0xee, 0x59, 0xf4, 0x77, 0x1a,
0x5b, 0x44, 0xc8, 0x13, 0x0b, 0x4e, 0x3e, 0xac, 0xca, 0x54, 0xa5, 0x6d, 0xda, 0x72, 0xb4, 0x64,
}
testKnownSigningValue(t, elliptic.P256(), d)
}
func TestECDSAKnownSigningValue_P384(t *testing.T) {
d := []byte{
0x53, 0x94, 0xf7, 0x97, 0x3e, 0xa8, 0x68, 0xc5, 0x2b, 0xf3, 0xff, 0x8d, 0x8c, 0xee, 0xb4, 0xdb,
0x90, 0xa6, 0x83, 0x65, 0x3b, 0x12, 0x48, 0x5d, 0x5f, 0x62, 0x7c, 0x3c, 0xe5, 0xab, 0xd8, 0x97,
0x8f, 0xc9, 0x67, 0x3d, 0x14, 0xa7, 0x1d, 0x92, 0x57, 0x47, 0x93, 0x16, 0x62, 0x49, 0x3c, 0x37,
}
testKnownSigningValue(t, elliptic.P384(), d)
}
func testKeyDerivation(t *testing.T, curve elliptic.Curve, d, expectedX, expectedY []byte) {
privKey := ECDSAKey(curve, d)
if e, a := d, privKey.D.Bytes(); bytes.Compare(e, a) != 0 {
t.Errorf("expected % x, got % x", e, a)
}
if e, a := expectedX, privKey.X.Bytes(); bytes.Compare(e, a) != 0 {
t.Errorf("expected % x, got % x", e, a)
}
if e, a := expectedY, privKey.Y.Bytes(); bytes.Compare(e, a) != 0 {
t.Errorf("expected % x, got % x", e, a)
}
}
func testKnownSigningValue(t *testing.T, curve elliptic.Curve, d []byte) {
signingKey := ECDSAKey(curve, d)
message := []byte{
0x59, 0x05, 0x23, 0x88, 0x77, 0xc7, 0x74, 0x21, 0xf7, 0x3e, 0x43, 0xee, 0x3d, 0xa6, 0xf2, 0xd9,
0xe2, 0xcc, 0xad, 0x5f, 0xc9, 0x42, 0xdc, 0xec, 0x0c, 0xbd, 0x25, 0x48, 0x29, 0x35, 0xfa, 0xaf,
0x41, 0x69, 0x83, 0xfe, 0x16, 0x5b, 0x1a, 0x04, 0x5e, 0xe2, 0xbc, 0xd2, 0xe6, 0xdc, 0xa3, 0xbd,
0xf4, 0x6c, 0x43, 0x10, 0xa7, 0x46, 0x1f, 0x9a, 0x37, 0x96, 0x0c, 0xa6, 0x72, 0xd3, 0xfe, 0xb5,
0x47, 0x3e, 0x25, 0x36, 0x05, 0xfb, 0x1d, 0xdf, 0xd2, 0x80, 0x65, 0xb5, 0x3c, 0xb5, 0x85, 0x8a,
0x8a, 0xd2, 0x81, 0x75, 0xbf, 0x9b, 0xd3, 0x86, 0xa5, 0xe4, 0x71, 0xea, 0x7a, 0x65, 0xc1, 0x7c,
0xc9, 0x34, 0xa9, 0xd7, 0x91, 0xe9, 0x14, 0x91, 0xeb, 0x37, 0x54, 0xd0, 0x37, 0x99, 0x79, 0x0f,
0xe2, 0xd3, 0x08, 0xd1, 0x61, 0x46, 0xd5, 0xc9, 0xb0, 0xd0, 0xde, 0xbd, 0x97, 0xd7, 0x9c, 0xe8,
}
sha256Hash := sha256.New()
_, err := io.Copy(sha256Hash, bytes.NewReader(message))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
msgHash := sha256Hash.Sum(nil)
msgSignature, err := signingKey.Sign(rand.Reader, msgHash, crypto.SHA256)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
verified, err := VerifySignature(&signingKey.PublicKey, msgHash, msgSignature)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !verified {
t.Fatalf("failed to verify message msgSignature")
}
}
func TestECDSAInvalidSignature_P256(t *testing.T) {
testInvalidSignature(t, elliptic.P256())
}
func TestECDSAInvalidSignature_P384(t *testing.T) {
testInvalidSignature(t, elliptic.P384())
}
func TestECDSAGenKeySignature_P256(t *testing.T) {
testGenKeySignature(t, elliptic.P256())
}
func TestECDSAGenKeySignature_P384(t *testing.T) {
testGenKeySignature(t, elliptic.P384())
}
func testInvalidSignature(t *testing.T, curve elliptic.Curve) {
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
message := []byte{
0x59, 0x05, 0x23, 0x88, 0x77, 0xc7, 0x74, 0x21, 0xf7, 0x3e, 0x43, 0xee, 0x3d, 0xa6, 0xf2, 0xd9,
0xe2, 0xcc, 0xad, 0x5f, 0xc9, 0x42, 0xdc, 0xec, 0x0c, 0xbd, 0x25, 0x48, 0x29, 0x35, 0xfa, 0xaf,
0x41, 0x69, 0x83, 0xfe, 0x16, 0x5b, 0x1a, 0x04, 0x5e, 0xe2, 0xbc, 0xd2, 0xe6, 0xdc, 0xa3, 0xbd,
0xf4, 0x6c, 0x43, 0x10, 0xa7, 0x46, 0x1f, 0x9a, 0x37, 0x96, 0x0c, 0xa6, 0x72, 0xd3, 0xfe, 0xb5,
0x47, 0x3e, 0x25, 0x36, 0x05, 0xfb, 0x1d, 0xdf, 0xd2, 0x80, 0x65, 0xb5, 0x3c, 0xb5, 0x85, 0x8a,
0x8a, 0xd2, 0x81, 0x75, 0xbf, 0x9b, 0xd3, 0x86, 0xa5, 0xe4, 0x71, 0xea, 0x7a, 0x65, 0xc1, 0x7c,
0xc9, 0x34, 0xa9, 0xd7, 0x91, 0xe9, 0x14, 0x91, 0xeb, 0x37, 0x54, 0xd0, 0x37, 0x99, 0x79, 0x0f,
0xe2, 0xd3, 0x08, 0xd1, 0x61, 0x46, 0xd5, 0xc9, 0xb0, 0xd0, 0xde, 0xbd, 0x97, 0xd7, 0x9c, 0xe8,
}
sha256Hash := sha256.New()
_, err = io.Copy(sha256Hash, bytes.NewReader(message))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
msgHash := sha256Hash.Sum(nil)
msgSignature, err := privateKey.Sign(rand.Reader, msgHash, crypto.SHA256)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
byteToFlip := 15
switch msgSignature[byteToFlip] {
case 0:
msgSignature[byteToFlip] = 0x0a
default:
msgSignature[byteToFlip] &^= msgSignature[byteToFlip]
}
verified, err := VerifySignature(&privateKey.PublicKey, msgHash, msgSignature)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if verified {
t.Fatalf("expected message verification to fail")
}
}
func testGenKeySignature(t *testing.T, curve elliptic.Curve) {
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
message := []byte{
0x59, 0x05, 0x23, 0x88, 0x77, 0xc7, 0x74, 0x21, 0xf7, 0x3e, 0x43, 0xee, 0x3d, 0xa6, 0xf2, 0xd9,
0xe2, 0xcc, 0xad, 0x5f, 0xc9, 0x42, 0xdc, 0xec, 0x0c, 0xbd, 0x25, 0x48, 0x29, 0x35, 0xfa, 0xaf,
0x41, 0x69, 0x83, 0xfe, 0x16, 0x5b, 0x1a, 0x04, 0x5e, 0xe2, 0xbc, 0xd2, 0xe6, 0xdc, 0xa3, 0xbd,
0xf4, 0x6c, 0x43, 0x10, 0xa7, 0x46, 0x1f, 0x9a, 0x37, 0x96, 0x0c, 0xa6, 0x72, 0xd3, 0xfe, 0xb5,
0x47, 0x3e, 0x25, 0x36, 0x05, 0xfb, 0x1d, 0xdf, 0xd2, 0x80, 0x65, 0xb5, 0x3c, 0xb5, 0x85, 0x8a,
0x8a, 0xd2, 0x81, 0x75, 0xbf, 0x9b, 0xd3, 0x86, 0xa5, 0xe4, 0x71, 0xea, 0x7a, 0x65, 0xc1, 0x7c,
0xc9, 0x34, 0xa9, 0xd7, 0x91, 0xe9, 0x14, 0x91, 0xeb, 0x37, 0x54, 0xd0, 0x37, 0x99, 0x79, 0x0f,
0xe2, 0xd3, 0x08, 0xd1, 0x61, 0x46, 0xd5, 0xc9, 0xb0, 0xd0, 0xde, 0xbd, 0x97, 0xd7, 0x9c, 0xe8,
}
sha256Hash := sha256.New()
_, err = io.Copy(sha256Hash, bytes.NewReader(message))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
msgHash := sha256Hash.Sum(nil)
msgSignature, err := privateKey.Sign(rand.Reader, msgHash, crypto.SHA256)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
verified, err := VerifySignature(&privateKey.PublicKey, msgHash, msgSignature)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !verified {
t.Fatalf("expected message verification to fail")
}
}
func TestECDSASignatureFormat(t *testing.T) {
asn1Signature := []byte{
0x30, 0x45, 0x02, 0x21, 0x00, 0xd7, 0xc5, 0xb9, 0x9e, 0x0b, 0xb1, 0x1a, 0x1f, 0x32, 0xda, 0x66, 0xe0, 0xff,
0x59, 0xb7, 0x8a, 0x5e, 0xb3, 0x94, 0x9c, 0x23, 0xb3, 0xfc, 0x1f, 0x18, 0xcc, 0xf6, 0x61, 0x67, 0x8b, 0xf1,
0xc1, 0x02, 0x20, 0x26, 0x4d, 0x8b, 0x7c, 0xaa, 0x52, 0x4c, 0xc0, 0x2e, 0x5f, 0xf6, 0x7e, 0x24, 0x82, 0xe5,
0xfb, 0xcb, 0xc7, 0x9b, 0x83, 0x0d, 0x19, 0x7e, 0x7a, 0x40, 0x37, 0x87, 0xdd, 0x1c, 0x93, 0x13, 0xc4,
}
x := []byte{
0x1c, 0xcb, 0xe9, 0x1c, 0x07, 0x5f, 0xc7, 0xf4, 0xf0, 0x33, 0xbf, 0xa2, 0x48, 0xdb, 0x8f, 0xcc,
0xd3, 0x56, 0x5d, 0xe9, 0x4b, 0xbf, 0xb1, 0x2f, 0x3c, 0x59, 0xff, 0x46, 0xc2, 0x71, 0xbf, 0x83,
}
y := []byte{
0xce, 0x40, 0x14, 0xc6, 0x88, 0x11, 0xf9, 0xa2, 0x1a, 0x1f, 0xdb, 0x2c, 0x0e, 0x61, 0x13, 0xe0,
0x6d, 0xb7, 0xca, 0x93, 0xb7, 0x40, 0x4e, 0x78, 0xdc, 0x7c, 0xcd, 0x5c, 0xa8, 0x9a, 0x4c, 0xa9,
}
publicKey, err := ECDSAPublicKey(elliptic.P256(), x, y)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
message := []byte{
0x59, 0x05, 0x23, 0x88, 0x77, 0xc7, 0x74, 0x21, 0xf7, 0x3e, 0x43, 0xee, 0x3d, 0xa6, 0xf2, 0xd9,
0xe2, 0xcc, 0xad, 0x5f, 0xc9, 0x42, 0xdc, 0xec, 0x0c, 0xbd, 0x25, 0x48, 0x29, 0x35, 0xfa, 0xaf,
0x41, 0x69, 0x83, 0xfe, 0x16, 0x5b, 0x1a, 0x04, 0x5e, 0xe2, 0xbc, 0xd2, 0xe6, 0xdc, 0xa3, 0xbd,
0xf4, 0x6c, 0x43, 0x10, 0xa7, 0x46, 0x1f, 0x9a, 0x37, 0x96, 0x0c, 0xa6, 0x72, 0xd3, 0xfe, 0xb5,
0x47, 0x3e, 0x25, 0x36, 0x05, 0xfb, 0x1d, 0xdf, 0xd2, 0x80, 0x65, 0xb5, 0x3c, 0xb5, 0x85, 0x8a,
0x8a, 0xd2, 0x81, 0x75, 0xbf, 0x9b, 0xd3, 0x86, 0xa5, 0xe4, 0x71, 0xea, 0x7a, 0x65, 0xc1, 0x7c,
0xc9, 0x34, 0xa9, 0xd7, 0x91, 0xe9, 0x14, 0x91, 0xeb, 0x37, 0x54, 0xd0, 0x37, 0x99, 0x79, 0x0f,
0xe2, 0xd3, 0x08, 0xd1, 0x61, 0x46, 0xd5, 0xc9, 0xb0, 0xd0, 0xde, 0xbd, 0x97, 0xd7, 0x9c, 0xe8,
}
hash := sha256.New()
_, err = io.Copy(hash, bytes.NewReader(message))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
msgHash := hash.Sum(nil)
verifySignature, err := VerifySignature(publicKey, msgHash, asn1Signature)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !verifySignature {
t.Fatalf("failed to verify signature")
}
}

View file

@ -0,0 +1,36 @@
package v4
const (
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
// UnsignedPayload indicates that the request payload body is unsigned
UnsignedPayload = "UNSIGNED-PAYLOAD"
// AmzAlgorithmKey indicates the signing algorithm
AmzAlgorithmKey = "X-Amz-Algorithm"
// AmzSecurityTokenKey indicates the security token to be used with temporary credentials
AmzSecurityTokenKey = "X-Amz-Security-Token"
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'
AmzDateKey = "X-Amz-Date"
// AmzCredentialKey is the access key ID and credential scope
AmzCredentialKey = "X-Amz-Credential"
// AmzSignedHeadersKey is the set of headers signed for the request
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
// AmzSignatureKey is the query parameter to store the SigV4 signature
AmzSignatureKey = "X-Amz-Signature"
// TimeFormat is the time format to be used in the X-Amz-Date header or query parameter
TimeFormat = "20060102T150405Z"
// ShortTimeFormat is the shorten time format used in the credential scope
ShortTimeFormat = "20060102"
// ContentSHAKey is the SHA256 of request body
ContentSHAKey = "X-Amz-Content-Sha256"
)

View file

@ -0,0 +1,88 @@
package v4
import (
"strings"
)
// Rules houses a set of Rule needed for validation of a
// string value
type Rules []Rule
// Rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that Rule
type Rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules
func (r Rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// MapRule generic Rule for maps
type MapRule map[string]struct{}
// IsValid for the map Rule satisfies whether it exists in the map
func (m MapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// AllowList is a generic Rule for whitelisting
type AllowList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (w AllowList) IsValid(value string) bool {
return w.Rule.IsValid(value)
}
// DenyList is a generic Rule for blacklisting
type DenyList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (b DenyList) IsValid(value string) bool {
return !b.Rule.IsValid(value)
}
// Patterns is a list of strings to match against
type Patterns []string
// IsValid for Patterns checks each pattern and returns if a match has
// been found
func (p Patterns) IsValid(value string) bool {
for _, pattern := range p {
if HasPrefixFold(value, pattern) {
return true
}
}
return false
}
// InclusiveRules rules allow for rules to depend on one another
type InclusiveRules []Rule
// IsValid will return true if all rules are true
func (r InclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}

View file

@ -0,0 +1,79 @@
package v4
// IgnoredPresignedHeaders is a list of headers that are ignored during signing
var IgnoredPresignedHeaders = Rules{
DenyList{
MapRule{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// IgnoredHeaders is a list of headers that are ignored during signing
// drop User-Agent header to be compatible with aws sdk java v1.
var IgnoredHeaders = Rules{
DenyList{
MapRule{
"Authorization": struct{}{},
//"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// RequiredSignedHeaders is a whitelist for Build canonical headers.
var RequiredSignedHeaders = Rules{
AllowList{
MapRule{
"Cache-Control": struct{}{},
"Content-Disposition": struct{}{},
"Content-Encoding": struct{}{},
"Content-Language": struct{}{},
"Content-Md5": struct{}{},
"Content-Type": struct{}{},
"Expires": struct{}{},
"If-Match": struct{}{},
"If-Modified-Since": struct{}{},
"If-None-Match": struct{}{},
"If-Unmodified-Since": struct{}{},
"Range": struct{}{},
"X-Amz-Acl": struct{}{},
"X-Amz-Copy-Source": struct{}{},
"X-Amz-Copy-Source-If-Match": struct{}{},
"X-Amz-Copy-Source-If-Modified-Since": struct{}{},
"X-Amz-Copy-Source-If-None-Match": struct{}{},
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{},
"X-Amz-Copy-Source-Range": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Grant-Full-control": struct{}{},
"X-Amz-Grant-Read": struct{}{},
"X-Amz-Grant-Read-Acp": struct{}{},
"X-Amz-Grant-Write": struct{}{},
"X-Amz-Grant-Write-Acp": struct{}{},
"X-Amz-Metadata-Directive": struct{}{},
"X-Amz-Mfa": struct{}{},
"X-Amz-Request-Payer": struct{}{},
"X-Amz-Server-Side-Encryption": struct{}{},
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{},
"X-Amz-Tagging": struct{}{},
},
},
Patterns{"X-Amz-Meta-"},
}
// AllowedQueryHoisting is a whitelist for Build query headers. The boolean value
// represents whether or not it is a pattern.
var AllowedQueryHoisting = InclusiveRules{
DenyList{RequiredSignedHeaders},
Patterns{"X-Amz-"},
}

View file

@ -0,0 +1,13 @@
package v4
import (
"crypto/hmac"
"crypto/sha256"
)
// HMACSHA256 computes a HMAC-SHA256 of data given the provided key.
func HMACSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}

View file

@ -0,0 +1,75 @@
package v4
import (
"net/http"
"strings"
)
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}

View file

@ -0,0 +1,36 @@
package v4
import "time"
// SigningTime provides a wrapper around a time.Time which provides cached values for SigV4 signing.
type SigningTime struct {
time.Time
timeFormat string
shortTimeFormat string
}
// NewSigningTime creates a new SigningTime given a time.Time
func NewSigningTime(t time.Time) SigningTime {
return SigningTime{
Time: t,
}
}
// TimeFormat provides a time formatted in the X-Amz-Date format.
func (m *SigningTime) TimeFormat() string {
return m.format(&m.timeFormat, TimeFormat)
}
// ShortTimeFormat provides a time formatted of 20060102.
func (m *SigningTime) ShortTimeFormat() string {
return m.format(&m.shortTimeFormat, ShortTimeFormat)
}
func (m *SigningTime) format(target *string, format string) string {
if len(*target) > 0 {
return *target
}
v := m.Time.Format(format)
*target = v
return v
}

View file

@ -0,0 +1,64 @@
package v4
import (
"net/url"
"strings"
)
const doubleSpace = " "
// StripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces.
func StripExcessSpaces(str string) string {
var j, k, l, m, spaces int
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
return str
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
return string(buf[:m])
}
// GetURIPath returns the escaped URI component from the provided URL
func GetURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -0,0 +1,75 @@
package v4
import (
"testing"
)
func TestStripExcessHeaders(t *testing.T) {
vals := []string{
"",
"123",
"1 2 3",
"1 2 3 ",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
expected := []string{
"",
"123",
"1 2 3",
"1 2 3",
"1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2",
"1 2",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
for i := 0; i < len(vals); i++ {
r := StripExcessSpaces(vals[i])
if e, a := expected[i], r; e != a {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
var stripExcessSpaceCases = []string{
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
`123 321 123 321`,
` 123 321 123 321 `,
` 123 321 123 321 `,
"123",
"1 2 3",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
func BenchmarkStripExcessSpaces(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, v := range stripExcessSpaceCases {
StripExcessSpaces(v)
}
}
}

View file

@ -0,0 +1,118 @@
package v4a
import (
"context"
"fmt"
"net/http"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// HTTPSigner is SigV4a HTTP signer implementation
type HTTPSigner interface {
SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optfns ...func(*SignerOptions)) error
}
// SignHTTPRequestMiddlewareOptions is the middleware options for constructing a SignHTTPRequestMiddleware.
type SignHTTPRequestMiddlewareOptions struct {
Credentials CredentialsProvider
Signer HTTPSigner
LogSigning bool
}
// SignHTTPRequestMiddleware is a middleware for signing an HTTP request using SigV4a.
type SignHTTPRequestMiddleware struct {
credentials CredentialsProvider
signer HTTPSigner
logSigning bool
}
// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given SignHTTPRequestMiddlewareOptions.
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
return &SignHTTPRequestMiddleware{
credentials: options.Credentials,
signer: options.Signer,
logSigning: options.LogSigning,
}
}
// ID the middleware identifier.
func (s *SignHTTPRequestMiddleware) ID() string {
return "Signing"
}
// HandleFinalize signs an HTTP request using SigV4a.
func (s *SignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
if !hasCredentialProvider(s.credentials) {
return next.HandleFinalize(ctx, in)
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unexpected request middleware type %T", in.Request)
}
signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
payloadHash := v4.GetPayloadHash(ctx)
if len(payloadHash) == 0 {
return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
}
credentials, err := s.credentials.RetrievePrivateKey(ctx)
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
}
signerOptions := []func(o *SignerOptions){
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
},
}
// existing DisableURIPathEscaping is equivalent in purpose
// to authentication scheme property DisableDoubleEncoding
//disableDoubleEncoding, overridden := internalauth.GetDisableDoubleEncoding(ctx) // internalauth "github.com/aws/aws-sdk-go-v2/internal/auth"
//if overridden {
// signerOptions = append(signerOptions, func(o *SignerOptions) {
// o.DisableURIPathEscaping = disableDoubleEncoding
// })
//}
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, []string{signingRegion}, time.Now().UTC(), signerOptions...)
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
}
return next.HandleFinalize(ctx, in)
}
func hasCredentialProvider(p CredentialsProvider) bool {
if p == nil {
return false
}
return true
}
// RegisterSigningMiddleware registers the SigV4a signing middleware to the stack. If a signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterSigningMiddleware(stack *middleware.Stack, signingMiddleware *SignHTTPRequestMiddleware) (err error) {
const signedID = "Signing"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}

View file

@ -0,0 +1,149 @@
package v4a
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"strings"
"testing"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
type stubCredentialsProviderFunc func(context.Context) (Credentials, error)
func (f stubCredentialsProviderFunc) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
return f(ctx)
}
type httpSignerFunc func(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error
func (f httpSignerFunc) SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
return f(ctx, credentials, r, payloadHash, service, regionSet, signingTime, optFns...)
}
func TestSignHTTPRequestMiddleware(t *testing.T) {
cases := map[string]struct {
creds CredentialsProvider
hash string
logSigning bool
expectedErr interface{}
}{
"success": {
creds: stubCredentials,
hash: "0123456789abcdef",
},
"error": {
creds: stubCredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
return Credentials{}, fmt.Errorf("credential error")
}),
hash: "",
expectedErr: &SigningError{},
},
"nil creds": {
creds: nil,
},
"with log signing": {
creds: stubCredentials,
hash: "0123456789abcdef",
logSigning: true,
},
}
const (
signingName = "serviceId"
signingRegion = "regionName"
)
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
c := &SignHTTPRequestMiddleware{
credentials: tt.creds,
signer: httpSignerFunc(
func(ctx context.Context,
credentials Credentials, r *http.Request, payloadHash string,
service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) error {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}
expectCreds, _ := tt.creds.RetrievePrivateKey(ctx)
if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
t.Error(diff)
}
if e, a := tt.hash, payloadHash; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := signingName, service; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
t.Error(diff)
}
return nil
}),
logSigning: tt.logSigning,
}
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
return out, metadata, err
})
ctx := awsmiddleware.SetSigningRegion(
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)
var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)
if len(tt.hash) != 0 {
ctx = v4.SetPayloadHash(ctx, tt.hash)
}
_, _, err := c.HandleFinalize(ctx, middleware.FinalizeInput{
Request: &smithyhttp.Request{Request: &http.Request{}},
}, next)
if err != nil && tt.expectedErr == nil {
t.Errorf("expected no error, got %v", err)
} else if err != nil && tt.expectedErr != nil {
e, a := tt.expectedErr, err
if !errors.As(a, &e) {
t.Errorf("expected error type %T, got %T", e, a)
}
} else if err == nil && tt.expectedErr != nil {
t.Errorf("expected error, got nil")
}
if tt.logSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
var (
_ middleware.FinalizeMiddleware = &SignHTTPRequestMiddleware{}
)

View file

@ -0,0 +1,116 @@
package v4a
import (
"context"
"fmt"
"net/http"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/middleware"
smithyHTTP "github.com/aws/smithy-go/transport/http"
)
// HTTPPresigner is an interface to a SigV4a signer that can sign create a
// presigned URL for a HTTP requests.
type HTTPPresigner interface {
PresignHTTP(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error)
}
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
type PresignHTTPRequestMiddlewareOptions struct {
CredentialsProvider CredentialsProvider
Presigner HTTPPresigner
LogSigning bool
}
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
// presigned URL for an HTTP request.
//
// Will short circuit the middleware stack and not forward onto the next
// Finalize handler.
type PresignHTTPRequestMiddleware struct {
credentialsProvider CredentialsProvider
presigner HTTPPresigner
logSigning bool
}
// NewPresignHTTPRequestMiddleware returns a new PresignHTTPRequestMiddleware
// initialized with the presigner.
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
return &PresignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
presigner: options.Presigner,
logSigning: options.LogSigning,
}
}
// ID provides the middleware ID.
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
// HandleFinalize will take the provided input and create a presigned url for
// the http request using the SigV4 presign authentication scheme.
func (s *PresignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*smithyHTTP.Request)
if !ok {
return out, metadata, &SigningError{
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
}
}
httpReq := req.Build(ctx)
if !hasCredentialProvider(s.credentialsProvider) {
out.Result = &v4.PresignedHTTPRequest{
URL: httpReq.URL.String(),
Method: httpReq.Method,
SignedHeader: http.Header{},
}
return out, metadata, nil
}
signingName := awsmiddleware.GetSigningName(ctx)
signingRegion := awsmiddleware.GetSigningRegion(ctx)
payloadHash := v4.GetPayloadHash(ctx)
if len(payloadHash) == 0 {
return out, metadata, &SigningError{
Err: fmt.Errorf("computed payload hash missing from context"),
}
}
credentials, err := s.credentialsProvider.RetrievePrivateKey(ctx)
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to retrieve credentials: %w", err),
}
}
u, h, err := s.presigner.PresignHTTP(ctx, credentials,
httpReq, payloadHash, signingName, []string{signingRegion}, time.Now(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to sign http request, %w", err),
}
}
out.Result = &v4.PresignedHTTPRequest{
URL: u,
Method: httpReq.Method,
SignedHeader: h,
}
return out, metadata, nil
}

View file

@ -0,0 +1,222 @@
package v4a
import (
"bytes"
"context"
"net/http"
"net/url"
"strings"
"testing"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
type httpPresignerFunc func(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error)
func (f httpPresignerFunc) PresignHTTP(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (
url string, signedHeader http.Header, err error,
) {
return f(ctx, credentials, r, payloadHash, service, regionSet, signingTime, optFns...)
}
func TestPresignHTTPRequestMiddleware(t *testing.T) {
cases := map[string]struct {
Request *http.Request
Creds CredentialsProvider
PayloadHash string
LogSigning bool
ExpectResult *v4.PresignedHTTPRequest
ExpectErr string
}{
"success": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "0123456789abcdef",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"error": {
Request: func() *http.Request {
return &http.Request{}
}(),
Creds: stubCredentials,
PayloadHash: "",
ExpectErr: "failed to sign request",
},
"anonymous creds": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "",
ExpectErr: "failed to sign request",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"nil creds": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: nil,
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"with log signing": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "0123456789abcdef",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
LogSigning: true,
},
}
const (
signingName = "serviceId"
signingRegion = "regionName"
)
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
m := &PresignHTTPRequestMiddleware{
credentialsProvider: tt.Creds,
presigner: httpPresignerFunc(func(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error) {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}
if !hasCredentialProvider(tt.Creds) {
t.Errorf("expect presigner not to be called for not credentials provider")
}
expectCreds, _ := tt.Creds.RetrievePrivateKey(context.Background())
if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
t.Error(diff)
}
if e, a := tt.PayloadHash, payloadHash; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := signingName, service; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
t.Error(diff)
}
return tt.ExpectResult.URL, tt.ExpectResult.SignedHeader, nil
}),
logSigning: tt.LogSigning,
}
next := middleware.FinalizeHandlerFunc(
func(ctx context.Context, in middleware.FinalizeInput) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
t.Errorf("expect next handler not to be called")
return out, metadata, err
})
ctx := awsmiddleware.SetSigningRegion(
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)
var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)
if len(tt.PayloadHash) != 0 {
ctx = v4.SetPayloadHash(ctx, tt.PayloadHash)
}
result, _, err := m.HandleFinalize(ctx, middleware.FinalizeInput{
Request: &smithyhttp.Request{
Request: tt.Request,
},
}, next)
if len(tt.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := tt.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if diff := cmpDiff(tt.ExpectResult, result.Result); len(diff) != 0 {
t.Errorf("expect result match\n%v", diff)
}
if tt.LogSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
var (
_ middleware.FinalizeMiddleware = &PresignHTTPRequestMiddleware{}
)

View file

@ -0,0 +1,18 @@
package v4a
import (
"bytes"
"context"
"crypto/ecdsa"
)
var stubCredentials = stubCredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
stubKey, err := ecdsa.GenerateKey(p256, bytes.NewReader(bytes.Repeat([]byte{1}, 40)))
if err != nil {
return Credentials{}, err
}
return Credentials{
Context: "STUB",
PrivateKey: stubKey,
}, nil
})

View file

@ -0,0 +1,85 @@
package v4a
import (
"context"
"fmt"
"time"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/auth"
"github.com/aws/smithy-go/logging"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// CredentialsAdapter adapts v4a.Credentials to smithy auth.Identity.
type CredentialsAdapter struct {
Credentials Credentials
}
var _ auth.Identity = (*CredentialsAdapter)(nil)
// Expiration returns the time of expiration for the credentials.
func (v *CredentialsAdapter) Expiration() time.Time {
return v.Credentials.Expires
}
// CredentialsProviderAdapter adapts v4a.CredentialsProvider to
// auth.IdentityResolver.
type CredentialsProviderAdapter struct {
Provider CredentialsProvider
}
var _ (auth.IdentityResolver) = (*CredentialsProviderAdapter)(nil)
// GetIdentity retrieves v4a credentials using the underlying provider.
func (v *CredentialsProviderAdapter) GetIdentity(ctx context.Context, _ smithy.Properties) (
auth.Identity, error,
) {
creds, err := v.Provider.RetrievePrivateKey(ctx)
if err != nil {
return nil, fmt.Errorf("get credentials: %w", err)
}
return &CredentialsAdapter{Credentials: creds}, nil
}
// SignerAdapter adapts v4a.HTTPSigner to smithy http.Signer.
type SignerAdapter struct {
Signer HTTPSigner
Logger logging.Logger
LogSigning bool
}
var _ (smithyhttp.Signer) = (*SignerAdapter)(nil)
// SignRequest signs the request with the provided identity.
func (v *SignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, identity auth.Identity, props smithy.Properties) error {
ca, ok := identity.(*CredentialsAdapter)
if !ok {
return fmt.Errorf("unexpected identity type: %T", identity)
}
name, ok := smithyhttp.GetSigV4SigningName(&props)
if !ok {
return fmt.Errorf("sigv4a signing name is required")
}
regions, ok := smithyhttp.GetSigV4ASigningRegions(&props)
if !ok {
return fmt.Errorf("sigv4a signing region is required")
}
hash := v4.GetPayloadHash(ctx)
err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, regions, time.Now(), func(o *SignerOptions) {
o.DisableURIPathEscaping, _ = smithyhttp.GetDisableDoubleEncoding(&props)
o.Logger = v.Logger
o.LogSigning = v.LogSigning
})
if err != nil {
return fmt.Errorf("sign http: %w", err)
}
return nil
}

View file

@ -0,0 +1,96 @@
package v4a
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
signerCrypto "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2/internal/crypto"
v4Internal "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2/internal/v4"
)
// EventStreamSigner is an AWS EventStream protocol signer.
type EventStreamSigner interface {
GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error)
}
// StreamSignerOptions is the configuration options for StreamSigner.
type StreamSignerOptions struct{}
// StreamSigner implements Signature Version 4 (SigV4) signing of event stream encoded payloads.
type StreamSigner struct {
options StreamSignerOptions
credentials Credentials
service string
prevSignature []byte
}
// NewStreamSigner returns a new AWS EventStream protocol signer.
func NewStreamSigner(credentials Credentials, service string, seedSignature []byte, optFns ...func(*StreamSignerOptions)) *StreamSigner {
o := StreamSignerOptions{}
for _, fn := range optFns {
fn(&o)
}
return &StreamSigner{
options: o,
credentials: credentials,
service: service,
prevSignature: seedSignature,
}
}
func (s *StreamSigner) VerifySignature(headers, payload []byte, signingTime time.Time, signature []byte, optFns ...func(*StreamSignerOptions)) error {
options := s.options
for _, fn := range optFns {
fn(&options)
}
prevSignature := s.prevSignature
st := v4Internal.NewSigningTime(signingTime)
scope := buildCredentialScope(st, s.service)
stringToSign := s.buildEventStreamStringToSign(headers, payload, prevSignature, scope, &st)
ok, err := signerCrypto.VerifySignature(&s.credentials.PrivateKey.PublicKey, makeHash(sha256.New(), []byte(stringToSign)), signature)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("v4a: invalid signature")
}
s.prevSignature = signature
return nil
}
func (s *StreamSigner) buildEventStreamStringToSign(headers, payload, previousSignature []byte, credentialScope string, signingTime *v4Internal.SigningTime) string {
hash := sha256.New()
return strings.Join([]string{
"AWS4-ECDSA-P256-SHA256-PAYLOAD",
signingTime.TimeFormat(),
credentialScope,
hex.EncodeToString(previousSignature),
hex.EncodeToString(makeHash(hash, headers)),
hex.EncodeToString(makeHash(hash, payload)),
}, "\n")
}
func buildCredentialScope(st v4Internal.SigningTime, service string) string {
return strings.Join([]string{
st.Format(shortTimeFormat),
service,
"aws4_request",
}, "/")
}

View file

@ -0,0 +1,581 @@
package v4a
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"math/big"
"net/http"
"net/textproto"
"net/url"
"sort"
"strings"
"time"
signerCrypto "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2/internal/crypto"
v4Internal "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2/internal/v4"
"github.com/aws/smithy-go/encoding/httpbinding"
"github.com/aws/smithy-go/logging"
)
const (
// AmzRegionSetKey represents the region set header used for sigv4a
AmzRegionSetKey = "X-Amz-Region-Set"
amzAlgorithmKey = v4Internal.AmzAlgorithmKey
amzSecurityTokenKey = v4Internal.AmzSecurityTokenKey
amzDateKey = v4Internal.AmzDateKey
amzCredentialKey = v4Internal.AmzCredentialKey
amzSignedHeadersKey = v4Internal.AmzSignedHeadersKey
authorizationHeader = "Authorization"
signingAlgorithm = "AWS4-ECDSA-P256-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
// EmptyStringSHA256 is a hex encoded SHA-256 hash of an empty string
EmptyStringSHA256 = v4Internal.EmptyStringSHA256
// Version of signing v4a
Version = "SigV4A"
)
var (
p256 elliptic.Curve
nMinusTwoP256 *big.Int
one = new(big.Int).SetInt64(1)
)
func init() {
// Ensure the elliptic curve parameters are initialized on package import rather then on first usage
p256 = elliptic.P256()
nMinusTwoP256 = new(big.Int).SetBytes(p256.Params().N.Bytes())
nMinusTwoP256 = nMinusTwoP256.Sub(nMinusTwoP256, new(big.Int).SetInt64(2))
}
// SignerOptions is the SigV4a signing options for constructing a Signer.
type SignerOptions struct {
Logger logging.Logger
LogSigning bool
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
// request header to the request's query string. This is most commonly used
// with pre-signed requests preventing headers from being added to the
// request's query string.
DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// siganture's canonical string's path. For services that do not need additional
// escaping then use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
}
// Signer is a SigV4a HTTP signing implementation
type Signer struct {
options SignerOptions
}
// NewSigner constructs a SigV4a Signer.
func NewSigner(optFns ...func(*SignerOptions)) *Signer {
options := SignerOptions{}
for _, fn := range optFns {
fn(&options)
}
return &Signer{options: options}
}
// deriveKeyFromAccessKeyPair derives a NIST P-256 PrivateKey from the given
// IAM AccessKey and SecretKey pair.
//
// Based on FIPS.186-4 Appendix B.4.2
func deriveKeyFromAccessKeyPair(accessKey, secretKey string) (*ecdsa.PrivateKey, error) {
params := p256.Params()
bitLen := params.BitSize // Testing random candidates does not require an additional 64 bits
counter := 0x01
buffer := make([]byte, 1+len(accessKey)) // 1 byte counter + len(accessKey)
kdfContext := bytes.NewBuffer(buffer)
inputKey := append([]byte("AWS4A"), []byte(secretKey)...)
d := new(big.Int)
for {
kdfContext.Reset()
kdfContext.WriteString(accessKey)
kdfContext.WriteByte(byte(counter))
key, err := signerCrypto.HMACKeyDerivation(sha256.New, bitLen, inputKey, []byte(signingAlgorithm), kdfContext.Bytes())
if err != nil {
return nil, err
}
// Check key first before calling SetBytes if key key is in fact a valid candidate.
// This ensures the byte slice is the correct length (32-bytes) to compare in constant-time
cmp, err := signerCrypto.ConstantTimeByteCompare(key, nMinusTwoP256.Bytes())
if err != nil {
return nil, err
}
if cmp == -1 {
d.SetBytes(key)
break
}
counter++
if counter > 0xFF {
return nil, fmt.Errorf("exhausted single byte external counter")
}
}
d = d.Add(d, one)
priv := new(ecdsa.PrivateKey)
priv.PublicKey.Curve = p256
priv.D = d
priv.PublicKey.X, priv.PublicKey.Y = p256.ScalarBaseMult(d.Bytes())
return priv, nil
}
type httpSigner struct {
Request *http.Request
ServiceName string
RegionSet []string
Time time.Time
Credentials Credentials
IsPreSign bool
Logger logging.Logger
Debug bool
// PayloadHash is the hex encoded SHA-256 hash of the request payload
// If len(PayloadHash) == 0 the signer will attempt to send the request
// as an unsigned payload. Note: Unsigned payloads only work for a subset of services.
PayloadHash string
DisableHeaderHoisting bool
DisableURIPathEscaping bool
}
// SignHTTP takes the provided http.Request, payload hash, service, regionSet, and time and signs using SigV4a.
// The passed in request will be modified in place.
func (s *Signer) SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
RegionSet: regionSet,
Credentials: credentials,
Time: signingTime.UTC(),
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
}
signedRequest, err := signer.Build()
if err != nil {
return err
}
logHTTPSigningInfo(ctx, options, signedRequest)
return nil
}
// VerifySignature checks sigv4a.
func (s *Signer) VerifySignature(credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, signature string, optFns ...func(*SignerOptions)) error {
return s.verifySignature(credentials, r, payloadHash, service, regionSet, signingTime, signature, false, optFns...)
}
// VerifyPresigned checks sigv4a.
func (s *Signer) VerifyPresigned(credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, signature string, optFns ...func(*SignerOptions)) error {
return s.verifySignature(credentials, r, payloadHash, service, regionSet, signingTime, signature, true, optFns...)
}
func (s *Signer) verifySignature(credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, signature string, isPresigned bool, optFns ...func(*SignerOptions)) error {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
RegionSet: regionSet,
Credentials: credentials,
Time: signingTime.UTC(),
IsPreSign: isPresigned,
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
}
signedReq, err := signer.Build()
if err != nil {
return err
}
logHTTPSigningInfo(context.TODO(), options, signedReq)
signatureRaw, err := hex.DecodeString(signature)
if err != nil {
return fmt.Errorf("decode hex signature: %w", err)
}
ok, err := signerCrypto.VerifySignature(&credentials.PrivateKey.PublicKey, makeHash(sha256.New(), []byte(signedReq.StringToSign)), signatureRaw)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("v4a: invalid signature")
}
return nil
}
// PresignHTTP takes the provided http.Request, payload hash, service, regionSet, and time and presigns using SigV4a
// Returns the presigned URL along with the headers that were signed with the request.
//
// PresignHTTP will not set the expires time of the presigned request
// automatically. To specify the expire duration for a request add the
// "X-Amz-Expires" query parameter on the request with the value as the
// duration in seconds the presigned URL should be considered valid for. This
// parameter is not used by all AWS services, and is most notable used by
// Amazon S3 APIs.
func (s *Signer) PresignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) (signedURI string, signedHeaders http.Header, err error) {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
RegionSet: regionSet,
Credentials: credentials,
Time: signingTime.UTC(),
IsPreSign: true,
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
}
signedRequest, err := signer.Build()
if err != nil {
return "", nil, err
}
logHTTPSigningInfo(ctx, options, signedRequest)
signedHeaders = make(http.Header)
// For the signed headers we canonicalize the header keys in the returned map.
// This avoids situations where can standard library double headers like host header. For example the standard
// library will set the Host header, even if it is present in lower-case form.
for k, v := range signedRequest.SignedHeaders {
key := textproto.CanonicalMIMEHeaderKey(k)
signedHeaders[key] = append(signedHeaders[key], v...)
}
return signedRequest.Request.URL.String(), signedHeaders, nil
}
func (s *httpSigner) setRequiredSigningFields(headers http.Header, query url.Values) {
amzDate := s.Time.Format(timeFormat)
if s.IsPreSign {
query.Set(AmzRegionSetKey, strings.Join(s.RegionSet, ","))
query.Set(amzDateKey, amzDate)
query.Set(amzAlgorithmKey, signingAlgorithm)
if len(s.Credentials.SessionToken) > 0 {
query.Set(amzSecurityTokenKey, s.Credentials.SessionToken)
}
return
}
headers.Set(AmzRegionSetKey, strings.Join(s.RegionSet, ","))
headers.Set(amzDateKey, amzDate)
if len(s.Credentials.SessionToken) > 0 {
headers.Set(amzSecurityTokenKey, s.Credentials.SessionToken)
}
}
func (s *httpSigner) Build() (signedRequest, error) {
req := s.Request
query := req.URL.Query()
headers := req.Header
s.setRequiredSigningFields(headers, query)
// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}
v4Internal.SanitizeHostForHeader(req)
credentialScope := s.buildCredentialScope()
credentialStr := s.Credentials.Context + "/" + credentialScope
if s.IsPreSign {
query.Set(amzCredentialKey, credentialStr)
}
unsignedHeaders := headers
if s.IsPreSign && !s.DisableHeaderHoisting {
urlValues := url.Values{}
urlValues, unsignedHeaders = buildQuery(v4Internal.AllowedQueryHoisting, unsignedHeaders)
for k := range urlValues {
query[k] = urlValues[k]
}
}
host := req.URL.Host
if len(req.Host) > 0 {
host = req.Host
}
var (
signedHeaders http.Header
signedHeadersStr string
canonicalHeaderStr string
)
if s.IsPreSign {
signedHeaders, signedHeadersStr, canonicalHeaderStr = s.buildCanonicalHeaders(host, v4Internal.IgnoredPresignedHeaders, unsignedHeaders, s.Request.ContentLength)
} else {
signedHeaders, signedHeadersStr, canonicalHeaderStr = s.buildCanonicalHeaders(host, v4Internal.IgnoredHeaders, unsignedHeaders, s.Request.ContentLength)
}
if s.IsPreSign {
query.Set(amzSignedHeadersKey, signedHeadersStr)
}
rawQuery := strings.Replace(query.Encode(), "+", "%20", -1)
canonicalURI := v4Internal.GetURIPath(req.URL)
if !s.DisableURIPathEscaping {
canonicalURI = httpbinding.EscapePath(canonicalURI, false)
}
canonicalString := s.buildCanonicalString(
req.Method,
canonicalURI,
rawQuery,
signedHeadersStr,
canonicalHeaderStr,
)
strToSign := s.buildStringToSign(credentialScope, canonicalString)
signingSignature, err := s.buildSignature(strToSign)
if err != nil {
return signedRequest{}, err
}
if s.IsPreSign {
rawQuery += "&X-Amz-Signature=" + signingSignature
} else {
headers[authorizationHeader] = append(headers[authorizationHeader][:0], buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature))
}
req.URL.RawQuery = rawQuery
return signedRequest{
Request: req,
SignedHeaders: signedHeaders,
CanonicalString: canonicalString,
StringToSign: strToSign,
PreSigned: s.IsPreSign,
}, nil
}
func buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature string) string {
const credential = "Credential="
const signedHeaders = "SignedHeaders="
const signature = "Signature="
const commaSpace = ", "
var parts strings.Builder
parts.Grow(len(signingAlgorithm) + 1 +
len(credential) + len(credentialStr) + len(commaSpace) +
len(signedHeaders) + len(signedHeadersStr) + len(commaSpace) +
len(signature) + len(signingSignature),
)
parts.WriteString(signingAlgorithm)
parts.WriteRune(' ')
parts.WriteString(credential)
parts.WriteString(credentialStr)
parts.WriteString(commaSpace)
parts.WriteString(signedHeaders)
parts.WriteString(signedHeadersStr)
parts.WriteString(commaSpace)
parts.WriteString(signature)
parts.WriteString(signingSignature)
return parts.String()
}
func (s *httpSigner) buildCredentialScope() string {
return strings.Join([]string{
s.Time.Format(shortTimeFormat),
s.ServiceName,
"aws4_request",
}, "/")
}
func buildQuery(r v4Internal.Rule, header http.Header) (url.Values, http.Header) {
query := url.Values{}
unsignedHeaders := http.Header{}
for k, h := range header {
if r.IsValid(k) {
query[k] = h
} else {
unsignedHeaders[k] = h
}
}
return query, unsignedHeaders
}
func (s *httpSigner) buildCanonicalHeaders(host string, rule v4Internal.Rule, header http.Header, length int64) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
signed = make(http.Header)
var headers []string
const hostHeader = "host"
headers = append(headers, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)
//const contentLengthHeader = "content-length"
//if length > 0 {
// headers = append(headers, contentLengthHeader)
// signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
//}
for k, v := range header {
if !rule.IsValid(k) {
continue // ignored header
}
lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}
sort.Strings(headers)
signedHeaders = strings.Join(headers, ";")
var canonicalHeaders strings.Builder
n := len(headers)
const colon = ':'
for i := 0; i < n; i++ {
if headers[i] == hostHeader {
canonicalHeaders.WriteString(hostHeader)
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(v4Internal.StripExcessSpaces(host))
} else {
canonicalHeaders.WriteString(headers[i])
canonicalHeaders.WriteRune(colon)
// Trim out leading, trailing, and dedup inner spaces from signed header values.
values := signed[headers[i]]
for j, v := range values {
cleanedValue := strings.TrimSpace(v4Internal.StripExcessSpaces(v))
canonicalHeaders.WriteString(cleanedValue)
if j < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()
return signed, signedHeaders, canonicalHeadersStr
}
func (s *httpSigner) buildCanonicalString(method, uri, query, signedHeaders, canonicalHeaders string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
s.PayloadHash,
}, "\n")
}
func (s *httpSigner) buildStringToSign(credentialScope, canonicalRequestString string) string {
return strings.Join([]string{
signingAlgorithm,
s.Time.Format(timeFormat),
credentialScope,
hex.EncodeToString(makeHash(sha256.New(), []byte(canonicalRequestString))),
}, "\n")
}
func makeHash(hash hash.Hash, b []byte) []byte {
hash.Reset()
hash.Write(b)
return hash.Sum(nil)
}
func (s *httpSigner) buildSignature(strToSign string) (string, error) {
sig, err := s.Credentials.PrivateKey.Sign(rand.Reader, makeHash(sha256.New(), []byte(strToSign)), crypto.SHA256)
if err != nil {
return "", err
}
return hex.EncodeToString(sig), nil
}
const logSignInfoMsg = `Request Signature:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func logHTTPSigningInfo(ctx context.Context, options SignerOptions, r signedRequest) {
if !options.LogSigning {
return
}
signedURLMsg := ""
if r.PreSigned {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, r.Request.URL.String())
}
logger := logging.WithContext(ctx, options.Logger)
logger.Logf(logging.Debug, logSignInfoMsg, r.CanonicalString, r.StringToSign, signedURLMsg)
}
type signedRequest struct {
Request *http.Request
SignedHeaders http.Header
CanonicalString string
StringToSign string
PreSigned bool
}

View file

@ -0,0 +1,429 @@
package v4a
import (
"context"
"encoding/hex"
"fmt"
"math/big"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2/internal/crypto"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/logging"
)
const (
accessKey = "AKISORANDOMAASORANDOM"
secretKey = "q+jcrXGc+0zWN6uzclKVhvMmUsIfRPa4rlRandom"
)
func TestDeriveECDSAKeyPairFromSecret(t *testing.T) {
privateKey, err := deriveKeyFromAccessKeyPair(accessKey, secretKey)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedX := func() *big.Int {
t.Helper()
b, ok := new(big.Int).SetString("15D242CEEBF8D8169FD6A8B5A746C41140414C3B07579038DA06AF89190FFFCB", 16)
if !ok {
t.Fatalf("failed to parse big integer")
}
return b
}()
expectedY := func() *big.Int {
t.Helper()
b, ok := new(big.Int).SetString("515242CEDD82E94799482E4C0514B505AFCCF2C0C98D6A553BF539F424C5EC0", 16)
if !ok {
t.Fatalf("failed to parse big integer")
}
return b
}()
if privateKey.X.Cmp(expectedX) != 0 {
t.Errorf("expected % X, got % X", expectedX, privateKey.X)
}
if privateKey.Y.Cmp(expectedY) != 0 {
t.Errorf("expected % X, got % X", expectedY, privateKey.Y)
}
}
func TestSignHTTP(t *testing.T) {
req := buildRequest("dynamodb", "us-east-1")
signer, credProvider := buildSigner(t, true)
key, err := credProvider.RetrievePrivateKey(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
err = signer.SignHTTP(context.Background(), key, req, EmptyStringSHA256, "dynamodb", []string{"us-east-1"}, time.Unix(0, 0))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedAlg := "AWS4-ECDSA-P256-SHA256"
expectedCredential := "AKISORANDOMAASORANDOM/19700101/dynamodb/aws4_request"
expectedSignedHeaders := "content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-region-set;x-amz-security-token;x-amz-target"
expectedStrToSignHash := "4ba7d0482cf4d5450cefdc067a00de1a4a715e444856fa3e1d85c35fb34d9730"
q := req.Header
validateAuthorization(t, q.Get("Authorization"), expectedAlg, expectedCredential, expectedSignedHeaders, expectedStrToSignHash)
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignHTTP_NoSessionToken(t *testing.T) {
req := buildRequest("dynamodb", "us-east-1")
signer, credProvider := buildSigner(t, false)
key, err := credProvider.RetrievePrivateKey(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
err = signer.SignHTTP(context.Background(), key, req, EmptyStringSHA256, "dynamodb", []string{"us-east-1"}, time.Unix(0, 0))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedAlg := "AWS4-ECDSA-P256-SHA256"
expectedCredential := "AKISORANDOMAASORANDOM/19700101/dynamodb/aws4_request"
expectedSignedHeaders := "content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-region-set;x-amz-target"
expectedStrToSignHash := "1aeefb422ae6aa0de7aec829da813e55cff35553cac212dffd5f9474c71e47ee"
q := req.Header
validateAuthorization(t, q.Get("Authorization"), expectedAlg, expectedCredential, expectedSignedHeaders, expectedStrToSignHash)
}
func TestPresignHTTP(t *testing.T) {
req := buildRequest("dynamodb", "us-east-1")
signer, credProvider := buildSigner(t, false)
key, err := credProvider.RetrievePrivateKey(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
query := req.URL.Query()
query.Set("X-Amz-Expires", "18000")
req.URL.RawQuery = query.Encode()
signedURL, _, err := signer.PresignHTTP(context.Background(), key, req, EmptyStringSHA256, "dynamodb", []string{"us-east-1"}, time.Unix(0, 0))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedAlg := "AWS4-ECDSA-P256-SHA256"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedCredential := "AKISORANDOMAASORANDOM/19700101/dynamodb/aws4_request"
expectedStrToSignHash := "d7ffbd2fab644384c056957e6ac38de4ae68246764b5f5df171b3824153b6397"
expectedTarget := "prefix.Operation"
signedReq, err := url.Parse(signedURL)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
q := signedReq.Query()
validateSignature(t, expectedStrToSignHash, q.Get("X-Amz-Signature"))
if e, a := expectedAlg, q.Get("X-Amz-Algorithm"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCredential, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", q.Get("X-Amz-Region-Set"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestPresignHTTP_BodyWithArrayRequest(t *testing.T) {
req := buildRequest("dynamodb", "us-east-1")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
signer, credProvider := buildSigner(t, true)
key, err := credProvider.RetrievePrivateKey(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
query := req.URL.Query()
query.Set("X-Amz-Expires", "300")
req.URL.RawQuery = query.Encode()
signedURI, _, err := signer.PresignHTTP(context.Background(), key, req, EmptyStringSHA256, "dynamodb", []string{"us-east-1"}, time.Unix(0, 0))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
signedReq, err := url.Parse(signedURI)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedAlg := "AWS4-ECDSA-P256-SHA256"
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedStrToSignHash := "acff64fd3689be96259d4112c3742ff79f4da0d813bc58a285dc1c4449760bec"
expectedCred := "AKISORANDOMAASORANDOM/19700101/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q := signedReq.Query()
validateSignature(t, expectedStrToSignHash, q.Get("X-Amz-Signature"))
if e, a := expectedAlg, q.Get("X-Amz-Algorithm"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", q.Get("X-Amz-Region-Set"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSign_buildCanonicalHeaders(t *testing.T) {
serviceName := "mockAPI"
region := "mock-region"
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, err := http.NewRequest("POST", endpoint, nil)
if err != nil {
t.Fatalf("failed to create request, %v", err)
}
req.Header.Set("FooInnerSpace", " inner space ")
req.Header.Set("FooLeadingSpace", " leading-space")
req.Header.Add("FooMultipleSpace", "no-space")
req.Header.Add("FooMultipleSpace", "\ttab-space")
req.Header.Add("FooMultipleSpace", "trailing-space ")
req.Header.Set("FooNoSpace", "no-space")
req.Header.Set("FooTabSpace", "\ttab-space\t")
req.Header.Set("FooTrailingSpace", "trailing-space ")
req.Header.Set("FooWrappedSpace", " wrapped-space ")
credProvider := &SymmetricCredentialAdaptor{
SymmetricProvider: staticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
},
},
}
key, err := credProvider.RetrievePrivateKey(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
ctx := &httpSigner{
Request: req,
ServiceName: serviceName,
RegionSet: []string{region},
Credentials: key,
Time: time.Date(2021, 10, 20, 12, 42, 0, 0, time.UTC),
}
build, err := ctx.Build()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectCanonicalString := strings.Join([]string{
`POST`,
`/`,
``,
`fooinnerspace:inner space`,
`fooleadingspace:leading-space`,
`foomultiplespace:no-space,tab-space,trailing-space`,
`foonospace:no-space`,
`footabspace:tab-space`,
`footrailingspace:trailing-space`,
`foowrappedspace:wrapped-space`,
`host:mockAPI.mock-region.amazonaws.com`,
`x-amz-date:20211020T124200Z`,
`x-amz-region-set:mock-region`,
``,
`fooinnerspace;fooleadingspace;foomultiplespace;foonospace;footabspace;footrailingspace;foowrappedspace;host;x-amz-date;x-amz-region-set`,
``,
}, "\n")
if diff := cmpDiff(expectCanonicalString, build.CanonicalString); diff != "" {
t.Errorf("expect match, got\n%s", diff)
}
}
func validateAuthorization(t *testing.T, authorization, expectedAlg, expectedCredential, expectedSignedHeaders, expectedStrToSignHash string) {
t.Helper()
split := strings.SplitN(authorization, " ", 2)
if len(split) != 2 {
t.Fatal("unexpected authorization header format")
}
if e, a := split[0], expectedAlg; e != a {
t.Errorf("expected %v, got %v", e, a)
}
keyValues := strings.Split(split[1], ", ")
seen := make(map[string]string)
for _, kv := range keyValues {
idx := strings.Index(kv, "=")
if idx == -1 {
continue
}
key, value := kv[:idx], kv[idx+1:]
seen[key] = value
}
if a, ok := seen["Credential"]; ok {
if expectedCredential != a {
t.Errorf("expected credential %v, got %v", expectedCredential, a)
}
} else {
t.Errorf("Credential not found in authorization string")
}
if a, ok := seen["SignedHeaders"]; ok {
if expectedSignedHeaders != a {
t.Errorf("expected signed headers %v, got %v", expectedSignedHeaders, a)
}
} else {
t.Errorf("SignedHeaders not found in authorization string")
}
if a, ok := seen["Signature"]; ok {
validateSignature(t, expectedStrToSignHash, a)
} else {
t.Errorf("signature not found in authorization string")
}
}
func validateSignature(t *testing.T, expectedHash, signature string) {
t.Helper()
pair, err := deriveKeyFromAccessKeyPair(accessKey, secretKey)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
hash, _ := hex.DecodeString(expectedHash)
sig, _ := hex.DecodeString(signature)
ok, err := crypto.VerifySignature(&pair.PublicKey, hash, sig)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !ok {
t.Errorf("failed to verify signing singature")
}
}
func buildRequest(serviceName, region string) *http.Request {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, nil)
req.URL.Opaque = "//example.org/bucket/key-._~,!@%23$%25^&*()"
req.Header.Set("X-Amz-Target", "prefix.Operation")
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("Content-Length", strconv.Itoa(1024))
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
return req
}
func buildSigner(t *testing.T, withToken bool) (*Signer, CredentialsProvider) {
creds := aws.Credentials{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
}
if withToken {
creds.SessionToken = "TOKEN"
}
return NewSigner(func(options *SignerOptions) {
options.Logger = loggerFunc(func(format string, v ...interface{}) {
t.Logf(format, v...)
})
}), &SymmetricCredentialAdaptor{
SymmetricProvider: staticCredentialsProvider{
Value: creds,
},
}
}
type loggerFunc func(format string, v ...interface{})
func (l loggerFunc) Logf(_ logging.Classification, format string, v ...interface{}) {
l(format, v...)
}
type staticCredentialsProvider struct {
Value aws.Credentials
}
func (s staticCredentialsProvider) Retrieve(_ context.Context) (aws.Credentials, error) {
v := s.Value
if v.AccessKeyID == "" || v.SecretAccessKey == "" {
return aws.Credentials{
Source: "Source Name",
}, fmt.Errorf("static credentials are empty")
}
if len(v.Source) == 0 {
v.Source = "Source Name"
}
return v, nil
}
func cmpDiff(e, a interface{}) string {
if !reflect.DeepEqual(e, a) {
return fmt.Sprintf("%v != %v", e, a)
}
return ""
}

View file

@ -0,0 +1,115 @@
package v4
import (
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
)
func lookupKey(service, region string) string {
var s strings.Builder
s.Grow(len(region) + len(service) + 3)
s.WriteString(region)
s.WriteRune('/')
s.WriteString(service)
return s.String()
}
type derivedKey struct {
AccessKey string
Date time.Time
Credential []byte
}
type derivedKeyCache struct {
values map[string]derivedKey
mutex sync.RWMutex
}
func newDerivedKeyCache() derivedKeyCache {
return derivedKeyCache{
values: make(map[string]derivedKey),
}
}
func (s *derivedKeyCache) Get(credentials aws.Credentials, service, region string, signingTime SigningTime) []byte {
key := lookupKey(service, region)
s.mutex.RLock()
if cred, ok := s.get(key, credentials, signingTime.Time); ok {
s.mutex.RUnlock()
return cred
}
s.mutex.RUnlock()
s.mutex.Lock()
if cred, ok := s.get(key, credentials, signingTime.Time); ok {
s.mutex.Unlock()
return cred
}
cred := deriveKey(credentials.SecretAccessKey, service, region, signingTime)
entry := derivedKey{
AccessKey: credentials.AccessKeyID,
Date: signingTime.Time,
Credential: cred,
}
s.values[key] = entry
s.mutex.Unlock()
return cred
}
func (s *derivedKeyCache) get(key string, credentials aws.Credentials, signingTime time.Time) ([]byte, bool) {
cacheEntry, ok := s.retrieveFromCache(key)
if ok && cacheEntry.AccessKey == credentials.AccessKeyID && isSameDay(signingTime, cacheEntry.Date) {
return cacheEntry.Credential, true
}
return nil, false
}
func (s *derivedKeyCache) retrieveFromCache(key string) (derivedKey, bool) {
if v, ok := s.values[key]; ok {
return v, true
}
return derivedKey{}, false
}
// SigningKeyDeriver derives a signing key from a set of credentials
type SigningKeyDeriver struct {
cache derivedKeyCache
}
// NewSigningKeyDeriver returns a new SigningKeyDeriver
func NewSigningKeyDeriver() *SigningKeyDeriver {
return &SigningKeyDeriver{
cache: newDerivedKeyCache(),
}
}
// DeriveKey returns a derived signing key from the given credentials to be used with SigV4 signing.
func (k *SigningKeyDeriver) DeriveKey(credential aws.Credentials, service, region string, signingTime SigningTime) []byte {
return k.cache.Get(credential, service, region, signingTime)
}
func deriveKey(secret, service, region string, t SigningTime) []byte {
hmacDate := HMACSHA256([]byte("AWS4"+secret), []byte(t.ShortTimeFormat()))
hmacRegion := HMACSHA256(hmacDate, []byte(region))
hmacService := HMACSHA256(hmacRegion, []byte(service))
return HMACSHA256(hmacService, []byte("aws4_request"))
}
func isSameDay(x, y time.Time) bool {
xYear, xMonth, xDay := x.Date()
yYear, yMonth, yDay := y.Date()
if xYear != yYear {
return false
}
if xMonth != yMonth {
return false
}
return xDay == yDay
}

View file

@ -0,0 +1,40 @@
package v4
// Signature Version 4 (SigV4) Constants
const (
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
// UnsignedPayload indicates that the request payload body is unsigned
UnsignedPayload = "UNSIGNED-PAYLOAD"
// AmzAlgorithmKey indicates the signing algorithm
AmzAlgorithmKey = "X-Amz-Algorithm"
// AmzSecurityTokenKey indicates the security token to be used with temporary credentials
AmzSecurityTokenKey = "X-Amz-Security-Token"
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'
AmzDateKey = "X-Amz-Date"
// AmzCredentialKey is the access key ID and credential scope
AmzCredentialKey = "X-Amz-Credential"
// AmzSignedHeadersKey is the set of headers signed for the request
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
// AmzSignatureKey is the query parameter to store the SigV4 signature
AmzSignatureKey = "X-Amz-Signature"
// TimeFormat is the time format to be used in the X-Amz-Date header or query parameter
TimeFormat = "20060102T150405Z"
// ShortTimeFormat is the shorten time format used in the credential scope
ShortTimeFormat = "20060102"
// ContentSHAKey is the SHA256 of request body
ContentSHAKey = "X-Amz-Content-Sha256"
// StreamingEventsPayload indicates that the request payload body is a signed event stream.
StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS"
)

View file

@ -0,0 +1,88 @@
package v4
import (
"strings"
)
// Rules houses a set of Rule needed for validation of a
// string value
type Rules []Rule
// Rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that Rule
type Rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules
func (r Rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// MapRule generic Rule for maps
type MapRule map[string]struct{}
// IsValid for the map Rule satisfies whether it exists in the map
func (m MapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// AllowList is a generic Rule for include listing
type AllowList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (w AllowList) IsValid(value string) bool {
return w.Rule.IsValid(value)
}
// ExcludeList is a generic Rule for exclude listing
type ExcludeList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (b ExcludeList) IsValid(value string) bool {
return !b.Rule.IsValid(value)
}
// Patterns is a list of strings to match against
type Patterns []string
// IsValid for Patterns checks each pattern and returns if a match has
// been found
func (p Patterns) IsValid(value string) bool {
for _, pattern := range p {
if HasPrefixFold(value, pattern) {
return true
}
}
return false
}
// InclusiveRules rules allow for rules to depend on one another
type InclusiveRules []Rule
// IsValid will return true if all rules are true
func (r InclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}

View file

@ -0,0 +1,84 @@
package v4
// IgnoredPresignedHeaders is a list of headers that are ignored during signing
var IgnoredPresignedHeaders = Rules{
ExcludeList{
MapRule{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
},
},
}
// IgnoredHeaders is a list of headers that are ignored during signing
// drop User-Agent header to be compatible with aws sdk java v1.
var IgnoredHeaders = Rules{
ExcludeList{
MapRule{
"Authorization": struct{}{},
//"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
},
},
}
// RequiredSignedHeaders is a allow list for Build canonical headers.
var RequiredSignedHeaders = Rules{
AllowList{
MapRule{
"Cache-Control": struct{}{},
"Content-Disposition": struct{}{},
"Content-Encoding": struct{}{},
"Content-Language": struct{}{},
"Content-Md5": struct{}{},
"Content-Type": struct{}{},
"Expires": struct{}{},
"If-Match": struct{}{},
"If-Modified-Since": struct{}{},
"If-None-Match": struct{}{},
"If-Unmodified-Since": struct{}{},
"Range": struct{}{},
"X-Amz-Acl": struct{}{},
"X-Amz-Copy-Source": struct{}{},
"X-Amz-Copy-Source-If-Match": struct{}{},
"X-Amz-Copy-Source-If-Modified-Since": struct{}{},
"X-Amz-Copy-Source-If-None-Match": struct{}{},
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{},
"X-Amz-Copy-Source-Range": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Expected-Bucket-Owner": struct{}{},
"X-Amz-Grant-Full-control": struct{}{},
"X-Amz-Grant-Read": struct{}{},
"X-Amz-Grant-Read-Acp": struct{}{},
"X-Amz-Grant-Write": struct{}{},
"X-Amz-Grant-Write-Acp": struct{}{},
"X-Amz-Metadata-Directive": struct{}{},
"X-Amz-Mfa": struct{}{},
"X-Amz-Request-Payer": struct{}{},
"X-Amz-Server-Side-Encryption": struct{}{},
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{},
"X-Amz-Server-Side-Encryption-Context": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{},
"X-Amz-Tagging": struct{}{},
},
},
Patterns{"X-Amz-Object-Lock-"},
Patterns{"X-Amz-Meta-"},
}
// AllowedQueryHoisting is a allowed list for Build query headers. The boolean value
// represents whether or not it is a pattern.
var AllowedQueryHoisting = InclusiveRules{
ExcludeList{RequiredSignedHeaders},
Patterns{"X-Amz-"},
}

View file

@ -0,0 +1,63 @@
package v4
import "testing"
func TestAllowedQueryHoisting(t *testing.T) {
cases := map[string]struct {
Header string
ExpectHoist bool
}{
"object-lock": {
Header: "X-Amz-Object-Lock-Mode",
ExpectHoist: false,
},
"s3 metadata": {
Header: "X-Amz-Meta-SomeName",
ExpectHoist: false,
},
"another header": {
Header: "X-Amz-SomeOtherHeader",
ExpectHoist: true,
},
"non X-AMZ header": {
Header: "X-SomeOtherHeader",
ExpectHoist: false,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
if e, a := c.ExpectHoist, AllowedQueryHoisting.IsValid(c.Header); e != a {
t.Errorf("expect hoist %v, was %v", e, a)
}
})
}
}
func TestIgnoredHeaders(t *testing.T) {
cases := map[string]struct {
Header string
ExpectIgnored bool
}{
"expect": {
Header: "Expect",
ExpectIgnored: true,
},
"authorization": {
Header: "Authorization",
ExpectIgnored: true,
},
"X-AMZ header": {
Header: "X-Amz-Content-Sha256",
ExpectIgnored: false,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
if e, a := c.ExpectIgnored, IgnoredHeaders.IsValid(c.Header); e == a {
t.Errorf("expect ignored %v, was %v", e, a)
}
})
}
}

View file

@ -0,0 +1,13 @@
package v4
import (
"crypto/hmac"
"crypto/sha256"
)
// HMACSHA256 computes a HMAC-SHA256 of data given the provided key.
func HMACSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}

View file

@ -0,0 +1,75 @@
package v4
import (
"net/http"
"strings"
)
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}

View file

@ -0,0 +1,13 @@
package v4
import "strings"
// BuildCredentialScope builds the Signature Version 4 (SigV4) signing scope
func BuildCredentialScope(signingTime SigningTime, region, service string) string {
return strings.Join([]string{
signingTime.ShortTimeFormat(),
region,
service,
"aws4_request",
}, "/")
}

View file

@ -0,0 +1,36 @@
package v4
import "time"
// SigningTime provides a wrapper around a time.Time which provides cached values for SigV4 signing.
type SigningTime struct {
time.Time
timeFormat string
shortTimeFormat string
}
// NewSigningTime creates a new SigningTime given a time.Time
func NewSigningTime(t time.Time) SigningTime {
return SigningTime{
Time: t,
}
}
// TimeFormat provides a time formatted in the X-Amz-Date format.
func (m *SigningTime) TimeFormat() string {
return m.format(&m.timeFormat, TimeFormat)
}
// ShortTimeFormat provides a time formatted of 20060102.
func (m *SigningTime) ShortTimeFormat() string {
return m.format(&m.shortTimeFormat, ShortTimeFormat)
}
func (m *SigningTime) format(target *string, format string) string {
if len(*target) > 0 {
return *target
}
v := m.Time.Format(format)
*target = v
return v
}

View file

@ -0,0 +1,80 @@
package v4
import (
"net/url"
"strings"
)
const doubleSpace = " "
// StripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
func StripExcessSpaces(str string) string {
var j, k, l, m, spaces int
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
return str
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
return string(buf[:m])
}
// GetURIPath returns the escaped URI component from the provided URL.
func GetURIPath(u *url.URL) string {
var uriPath string
if len(u.Opaque) > 0 {
const schemeSep, pathSep, queryStart = "//", "/", "?"
opaque := u.Opaque
// Cut off the query string if present.
if idx := strings.Index(opaque, queryStart); idx >= 0 {
opaque = opaque[:idx]
}
// Cutout the scheme separator if present.
if strings.Index(opaque, schemeSep) == 0 {
opaque = opaque[len(schemeSep):]
}
// capture URI path starting with first path separator.
if idx := strings.Index(opaque, pathSep); idx >= 0 {
uriPath = opaque[idx:]
}
} else {
uriPath = u.EscapedPath()
}
if len(uriPath) == 0 {
uriPath = "/"
}
return uriPath
}

View file

@ -0,0 +1,158 @@
package v4
import (
"net/http"
"net/url"
"testing"
)
func lazyURLParse(v string) func() (*url.URL, error) {
return func() (*url.URL, error) {
return url.Parse(v)
}
}
func TestGetURIPath(t *testing.T) {
cases := map[string]struct {
getURL func() (*url.URL, error)
expect string
}{
// Cases
"with scheme": {
getURL: lazyURLParse("https://localhost:9000"),
expect: "/",
},
"no port, with scheme": {
getURL: lazyURLParse("https://localhost"),
expect: "/",
},
"without scheme": {
getURL: lazyURLParse("localhost:9000"),
expect: "/",
},
"without scheme, with path": {
getURL: lazyURLParse("localhost:9000/abc123"),
expect: "/abc123",
},
"without scheme, with separator": {
getURL: lazyURLParse("//localhost:9000"),
expect: "/",
},
"no port, without scheme, with separator": {
getURL: lazyURLParse("//localhost"),
expect: "/",
},
"without scheme, with separator, with path": {
getURL: lazyURLParse("//localhost:9000/abc123"),
expect: "/abc123",
},
"no port, without scheme, with separator, with path": {
getURL: lazyURLParse("//localhost/abc123"),
expect: "/abc123",
},
"opaque with query string": {
getURL: lazyURLParse("localhost:9000/abc123?efg=456"),
expect: "/abc123",
},
"failing test": {
getURL: func() (*url.URL, error) {
endpoint := "https://service.region.amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, nil)
u := req.URL
u.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
query := u.Query()
query.Set("some-query-key", "value")
u.RawQuery = query.Encode()
return u, nil
},
expect: "/bucket/key-._~,!@#$%^&*()",
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
u, err := c.getURL()
if err != nil {
t.Fatalf("failed to get URL, %v", err)
}
actual := GetURIPath(u)
if e, a := c.expect, actual; e != a {
t.Errorf("expect %v path, got %v", e, a)
}
})
}
}
func TestStripExcessHeaders(t *testing.T) {
vals := []string{
"",
"123",
"1 2 3",
"1 2 3 ",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
expected := []string{
"",
"123",
"1 2 3",
"1 2 3",
"1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2",
"1 2",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
for i := 0; i < len(vals); i++ {
r := StripExcessSpaces(vals[i])
if e, a := expected[i], r; e != a {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
var stripExcessSpaceCases = []string{
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
`123 321 123 321`,
` 123 321 123 321 `,
` 123 321 123 321 `,
"123",
"1 2 3",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
func BenchmarkStripExcessSpaces(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, v := range stripExcessSpaceCases {
StripExcessSpaces(v)
}
}
}

View file

@ -0,0 +1,87 @@
package v4
import (
"context"
"crypto/sha256"
"encoding/hex"
"strings"
"time"
v4Internal "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/internal/v4"
"github.com/aws/aws-sdk-go-v2/aws"
)
// EventStreamSigner is an AWS EventStream protocol signer.
type EventStreamSigner interface {
GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error)
}
// StreamSignerOptions is the configuration options for StreamSigner.
type StreamSignerOptions struct{}
// StreamSigner implements Signature Version 4 (SigV4) signing of event stream encoded payloads.
type StreamSigner struct {
options StreamSignerOptions
credentials aws.Credentials
service string
region string
prevSignature []byte
signingKeyDeriver *v4Internal.SigningKeyDeriver
}
// NewStreamSigner returns a new AWS EventStream protocol signer.
func NewStreamSigner(credentials aws.Credentials, service, region string, seedSignature []byte, optFns ...func(*StreamSignerOptions)) *StreamSigner {
o := StreamSignerOptions{}
for _, fn := range optFns {
fn(&o)
}
return &StreamSigner{
options: o,
credentials: credentials,
service: service,
region: region,
signingKeyDeriver: v4Internal.NewSigningKeyDeriver(),
prevSignature: seedSignature,
}
}
// GetSignature signs the provided header and payload bytes.
func (s *StreamSigner) GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error) {
options := s.options
for _, fn := range optFns {
fn(&options)
}
prevSignature := s.prevSignature
st := v4Internal.NewSigningTime(signingTime)
sigKey := s.signingKeyDeriver.DeriveKey(s.credentials, s.service, s.region, st)
scope := v4Internal.BuildCredentialScope(st, s.region, s.service)
stringToSign := s.buildEventStreamStringToSign(headers, payload, prevSignature, scope, &st)
signature := v4Internal.HMACSHA256(sigKey, []byte(stringToSign))
s.prevSignature = signature
return signature, nil
}
func (s *StreamSigner) buildEventStreamStringToSign(headers, payload, previousSignature []byte, credentialScope string, signingTime *v4Internal.SigningTime) string {
hash := sha256.New()
return strings.Join([]string{
"AWS4-HMAC-SHA256-PAYLOAD",
signingTime.TimeFormat(),
credentialScope,
hex.EncodeToString(previousSignature),
hex.EncodeToString(makeHash(hash, headers)),
hex.EncodeToString(makeHash(hash, payload)),
}, "\n")
}

View file

@ -0,0 +1,574 @@
// Package v4 implements the AWS signature version 4 algorithm (commonly known
// as SigV4).
//
// For more information about SigV4, see [Signing AWS API requests] in the IAM
// user guide.
//
// While this implementation CAN work in an external context, it is developed
// primarily for SDK use and you may encounter fringe behaviors around header
// canonicalization.
//
// # Pre-escaping a request URI
//
// AWS v4 signature validation requires that the canonical string's URI path
// component must be the escaped form of the HTTP request's path.
//
// The Go HTTP client will perform escaping automatically on the HTTP request.
// This may cause signature validation errors because the request differs from
// the URI path or query from which the signature was generated.
//
// Because of this, we recommend that you explicitly escape the request when
// using this signer outside of the SDK to prevent possible signature mismatch.
// This can be done by setting URL.Opaque on the request. The signer will
// prefer that value, falling back to the return of URL.EscapedPath if unset.
//
// When setting URL.Opaque you must do so in the form of:
//
// "//<hostname>/<path>"
//
// // e.g.
// "//example.com/some/path"
//
// The leading "//" and hostname are required or the escaping will not work
// correctly.
//
// The TestStandaloneSign unit test provides a complete example of using the
// signer outside of the SDK and pre-escaping the URI path.
//
// [Signing AWS API requests]: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_aws-signing.html
package v4
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"net/http"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
"time"
v4Internal "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/internal/v4"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/encoding/httpbinding"
"github.com/aws/smithy-go/logging"
)
const (
signingAlgorithm = "AWS4-HMAC-SHA256"
authorizationHeader = "Authorization"
// Version of signing v4
Version = "SigV4"
)
// HTTPSigner is an interface to a SigV4 signer that can sign HTTP requests
type HTTPSigner interface {
SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error
}
type keyDerivator interface {
DeriveKey(credential aws.Credentials, service, region string, signingTime v4Internal.SigningTime) []byte
}
// SignerOptions is the SigV4 Signer options.
type SignerOptions struct {
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
// request header to the request's query string. This is most commonly used
// with pre-signed requests preventing headers from being added to the
// request's query string.
DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// siganture's canonical string's path. For services that do not need additional
// escaping then use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// The logger to send log messages to.
Logger logging.Logger
// Enable logging of signed requests.
// This will enable logging of the canonical request, the string to sign, and for presigning the subsequent
// presigned URL.
LogSigning bool
// Disables setting the session token on the request as part of signing
// through X-Amz-Security-Token. This is needed for variations of v4 that
// present the token elsewhere.
DisableSessionToken bool
}
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
type Signer struct {
options SignerOptions
keyDerivator keyDerivator
}
// NewSigner returns a new SigV4 Signer
func NewSigner(optFns ...func(signer *SignerOptions)) *Signer {
options := SignerOptions{}
for _, fn := range optFns {
fn(&options)
}
return &Signer{options: options, keyDerivator: v4Internal.NewSigningKeyDeriver()}
}
type httpSigner struct {
Request *http.Request
ServiceName string
Region string
Time v4Internal.SigningTime
Credentials aws.Credentials
KeyDerivator keyDerivator
IsPreSign bool
PayloadHash string
DisableHeaderHoisting bool
DisableURIPathEscaping bool
DisableSessionToken bool
}
func (s *httpSigner) Build() (signedRequest, error) {
req := s.Request
query := req.URL.Query()
headers := req.Header
s.setRequiredSigningFields(headers, query)
// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}
v4Internal.SanitizeHostForHeader(req)
credentialScope := s.buildCredentialScope()
credentialStr := s.Credentials.AccessKeyID + "/" + credentialScope
if s.IsPreSign {
query.Set(v4Internal.AmzCredentialKey, credentialStr)
}
unsignedHeaders := headers
if s.IsPreSign && !s.DisableHeaderHoisting {
var urlValues url.Values
urlValues, unsignedHeaders = buildQuery(v4Internal.AllowedQueryHoisting, headers)
for k := range urlValues {
query[k] = urlValues[k]
}
}
host := req.URL.Host
if len(req.Host) > 0 {
host = req.Host
}
var (
signedHeaders http.Header
signedHeadersStr string
canonicalHeaderStr string
)
if s.IsPreSign {
signedHeaders, signedHeadersStr, canonicalHeaderStr = s.buildCanonicalHeaders(host, v4Internal.IgnoredPresignedHeaders, unsignedHeaders, s.Request.ContentLength)
} else {
signedHeaders, signedHeadersStr, canonicalHeaderStr = s.buildCanonicalHeaders(host, v4Internal.IgnoredHeaders, unsignedHeaders, s.Request.ContentLength)
}
if s.IsPreSign {
query.Set(v4Internal.AmzSignedHeadersKey, signedHeadersStr)
}
var rawQuery strings.Builder
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
canonicalURI := v4Internal.GetURIPath(req.URL)
if !s.DisableURIPathEscaping {
canonicalURI = httpbinding.EscapePath(canonicalURI, false)
}
canonicalString := s.buildCanonicalString(
req.Method,
canonicalURI,
rawQuery.String(),
signedHeadersStr,
canonicalHeaderStr,
)
strToSign := s.buildStringToSign(credentialScope, canonicalString)
signingSignature, err := s.buildSignature(strToSign)
if err != nil {
return signedRequest{}, err
}
if s.IsPreSign {
rawQuery.WriteString("&X-Amz-Signature=")
rawQuery.WriteString(signingSignature)
} else {
headers[authorizationHeader] = append(headers[authorizationHeader][:0], buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature))
}
req.URL.RawQuery = rawQuery.String()
return signedRequest{
Request: req,
SignedHeaders: signedHeaders,
CanonicalString: canonicalString,
StringToSign: strToSign,
PreSigned: s.IsPreSign,
}, nil
}
func buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature string) string {
const credential = "Credential="
const signedHeaders = "SignedHeaders="
const signature = "Signature="
const commaSpace = ", "
var parts strings.Builder
parts.Grow(len(signingAlgorithm) + 1 +
len(credential) + len(credentialStr) + 2 +
len(signedHeaders) + len(signedHeadersStr) + 2 +
len(signature) + len(signingSignature),
)
parts.WriteString(signingAlgorithm)
parts.WriteRune(' ')
parts.WriteString(credential)
parts.WriteString(credentialStr)
parts.WriteString(commaSpace)
parts.WriteString(signedHeaders)
parts.WriteString(signedHeadersStr)
parts.WriteString(commaSpace)
parts.WriteString(signature)
parts.WriteString(signingSignature)
return parts.String()
}
// SignHTTP signs AWS v4 requests with the provided payload hash, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// The payloadHash is the hex encoded SHA-256 hash of the request payload, and
// must be provided. Even if the request has no payload (aka body). If the
// request has no payload you should use the hex encoded SHA-256 of an empty
// string as the payloadHash value.
//
// "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
//
// Some services such as Amazon S3 accept alternative values for the payload
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
// included in the request signature.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The passed in request will be modified in place.
func (s Signer) SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(options *SignerOptions)) error {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
Region: region,
Credentials: credentials,
Time: v4Internal.NewSigningTime(signingTime.UTC()),
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
DisableSessionToken: options.DisableSessionToken,
KeyDerivator: s.keyDerivator,
}
signedRequest, err := signer.Build()
if err != nil {
return err
}
logSigningInfo(ctx, options, &signedRequest, false)
return nil
}
// PresignHTTP signs AWS v4 requests with the payload hash, service name, region
// the request is made to, and time the request is signed at. The signTime
// allows you to specify that a request is signed for the future, and cannot
// be used until then.
//
// Returns the signed URL and the map of HTTP headers that were included in the
// signature or an error if signing the request failed. For presigned requests
// these headers and their values must be included on the HTTP request when it
// is made. This is helpful to know what header values need to be shared with
// the party the presigned request will be distributed to.
//
// The payloadHash is the hex encoded SHA-256 hash of the request payload, and
// must be provided. Even if the request has no payload (aka body). If the
// request has no payload you should use the hex encoded SHA-256 of an empty
// string as the payloadHash value.
//
// "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
//
// Some services such as Amazon S3 accept alternative values for the payload
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
// included in the request signature.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
//
// PresignHTTP differs from SignHTTP in that it will sign the request using
// query string instead of header values. This allows you to share the
// Presigned Request's URL with third parties, or distribute it throughout your
// system with minimal dependencies.
//
// PresignHTTP will not set the expires time of the presigned request
// automatically. To specify the expire duration for a request add the
// "X-Amz-Expires" query parameter on the request with the value as the
// duration in seconds the presigned URL should be considered valid for. This
// parameter is not used by all AWS services, and is most notable used by
// Amazon S3 APIs.
//
// expires := 20 * time.Minute
// query := req.URL.Query()
// query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10))
// req.URL.RawQuery = query.Encode()
//
// This method does not modify the provided request.
func (s *Signer) PresignHTTP(
ctx context.Context, credentials aws.Credentials, r *http.Request,
payloadHash string, service string, region string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (signedURI string, signedHeaders http.Header, err error) {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r.Clone(r.Context()),
PayloadHash: payloadHash,
ServiceName: service,
Region: region,
Credentials: credentials,
Time: v4Internal.NewSigningTime(signingTime.UTC()),
IsPreSign: true,
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
DisableSessionToken: options.DisableSessionToken,
KeyDerivator: s.keyDerivator,
}
signedRequest, err := signer.Build()
if err != nil {
return "", nil, err
}
logSigningInfo(ctx, options, &signedRequest, true)
signedHeaders = make(http.Header)
// For the signed headers we canonicalize the header keys in the returned map.
// This avoids situations where can standard library double headers like host header. For example the standard
// library will set the Host header, even if it is present in lower-case form.
for k, v := range signedRequest.SignedHeaders {
key := textproto.CanonicalMIMEHeaderKey(k)
signedHeaders[key] = append(signedHeaders[key], v...)
}
return signedRequest.Request.URL.String(), signedHeaders, nil
}
func (s *httpSigner) buildCredentialScope() string {
return v4Internal.BuildCredentialScope(s.Time, s.Region, s.ServiceName)
}
func buildQuery(r v4Internal.Rule, header http.Header) (url.Values, http.Header) {
query := url.Values{}
unsignedHeaders := http.Header{}
// A list of headers to be converted to lower case to mitigate a limitation from S3
lowerCaseHeaders := map[string]string{
"X-Amz-Expected-Bucket-Owner": "x-amz-expected-bucket-owner", // see #2508
"X-Amz-Request-Payer": "x-amz-request-payer", // see #2764
}
for k, h := range header {
if newKey, ok := lowerCaseHeaders[k]; ok {
k = newKey
}
if r.IsValid(k) {
query[k] = h
} else {
unsignedHeaders[k] = h
}
}
return query, unsignedHeaders
}
func (s *httpSigner) buildCanonicalHeaders(host string, rule v4Internal.Rule, header http.Header, length int64) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
signed = make(http.Header)
var headers []string
const hostHeader = "host"
headers = append(headers, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)
const contentLengthHeader = "content-length"
if length > 0 {
headers = append(headers, contentLengthHeader)
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
}
for k, v := range header {
if !rule.IsValid(k) {
continue // ignored header
}
if strings.EqualFold(k, contentLengthHeader) {
// prevent signing already handled content-length header.
continue
}
lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}
sort.Strings(headers)
signedHeaders = strings.Join(headers, ";")
var canonicalHeaders strings.Builder
n := len(headers)
const colon = ':'
for i := 0; i < n; i++ {
if headers[i] == hostHeader {
canonicalHeaders.WriteString(hostHeader)
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(v4Internal.StripExcessSpaces(host))
} else {
canonicalHeaders.WriteString(headers[i])
canonicalHeaders.WriteRune(colon)
// Trim out leading, trailing, and dedup inner spaces from signed header values.
values := signed[headers[i]]
for j, v := range values {
cleanedValue := strings.TrimSpace(v4Internal.StripExcessSpaces(v))
canonicalHeaders.WriteString(cleanedValue)
if j < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()
return signed, signedHeaders, canonicalHeadersStr
}
func (s *httpSigner) buildCanonicalString(method, uri, query, signedHeaders, canonicalHeaders string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
s.PayloadHash,
}, "\n")
}
func (s *httpSigner) buildStringToSign(credentialScope, canonicalRequestString string) string {
return strings.Join([]string{
signingAlgorithm,
s.Time.TimeFormat(),
credentialScope,
hex.EncodeToString(makeHash(sha256.New(), []byte(canonicalRequestString))),
}, "\n")
}
func makeHash(hash hash.Hash, b []byte) []byte {
hash.Reset()
hash.Write(b)
return hash.Sum(nil)
}
func (s *httpSigner) buildSignature(strToSign string) (string, error) {
key := s.KeyDerivator.DeriveKey(s.Credentials, s.ServiceName, s.Region, s.Time)
return hex.EncodeToString(v4Internal.HMACSHA256(key, []byte(strToSign))), nil
}
func (s *httpSigner) setRequiredSigningFields(headers http.Header, query url.Values) {
amzDate := s.Time.TimeFormat()
if s.IsPreSign {
query.Set(v4Internal.AmzAlgorithmKey, signingAlgorithm)
sessionToken := s.Credentials.SessionToken
if !s.DisableSessionToken && len(sessionToken) > 0 {
query.Set("X-Amz-Security-Token", sessionToken)
}
query.Set(v4Internal.AmzDateKey, amzDate)
return
}
headers[v4Internal.AmzDateKey] = append(headers[v4Internal.AmzDateKey][:0], amzDate)
if !s.DisableSessionToken && len(s.Credentials.SessionToken) > 0 {
headers[v4Internal.AmzSecurityTokenKey] = append(headers[v4Internal.AmzSecurityTokenKey][:0], s.Credentials.SessionToken)
}
}
func logSigningInfo(ctx context.Context, options SignerOptions, request *signedRequest, isPresign bool) {
if !options.LogSigning {
return
}
signedURLMsg := ""
if isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, request.Request.URL.String())
}
logger := logging.WithContext(ctx, options.Logger)
logger.Logf(logging.Debug, logSignInfoMsg, request.CanonicalString, request.StringToSign, signedURLMsg)
}
type signedRequest struct {
Request *http.Request
SignedHeaders http.Header
CanonicalString string
StringToSign string
PreSigned bool
}
const logSignInfoMsg = `Request Signature:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`

View file

@ -0,0 +1,363 @@
package v4
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
v4Internal "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/internal/v4"
"github.com/aws/aws-sdk-go-v2/aws"
)
var testCredentials = aws.Credentials{AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "SESSION"}
func buildRequest(serviceName, region, body string) (*http.Request, string) {
reader := strings.NewReader(body)
return buildRequestWithBodyReader(serviceName, region, reader)
}
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, string) {
var bodyLen int
type lenner interface {
Len() int
}
if lr, ok := body.(lenner); ok {
bodyLen = lr.Len()
}
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, body)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Set("X-Amz-Target", "prefix.Operation")
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
if bodyLen > 0 {
req.ContentLength = int64(bodyLen)
}
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
h := sha256.New()
_, _ = io.Copy(h, body)
payloadHash := hex.EncodeToString(h.Sum(nil))
return req, payloadHash
}
func TestPresignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
query := req.URL.Query()
query.Set("X-Amz-Expires", "300")
req.URL.RawQuery = query.Encode()
signer := NewSigner()
signed, headers, err := signer.PresignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q, err := url.ParseQuery(signed[strings.Index(signed, "?"):])
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
for _, h := range strings.Split(expectedHeaders, ";") {
v := headers.Get(h)
if len(v) == 0 {
t.Errorf("expect %v, to be present in header map", h)
}
}
}
func TestPresignBodyWithArrayRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
query := req.URL.Query()
query.Set("X-Amz-Expires", "300")
req.URL.RawQuery = query.Encode()
signer := NewSigner()
signed, headers, err := signer.PresignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
q, err := url.ParseQuery(signed[strings.Index(signed, "?"):])
if err != nil {
t.Errorf("expect no error, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
t.Errorf("expect %v to be empty, was not", a)
}
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
for _, h := range strings.Split(expectedHeaders, ";") {
v := headers.Get(h)
if len(v) == 0 {
t.Errorf("expect %v, to be present in header map", h)
}
}
}
func TestSignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
signer := NewSigner()
err := signer.SignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
q := req.Header
if e, a := expectedSig, q.Get("Authorization"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestBuildCanonicalRequest(t *testing.T) {
req, _ := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
ctx := &httpSigner{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Time: v4Internal.NewSigningTime(time.Now()),
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
}
build, err := ctx.Build()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=a&Foo=m&Foo=o&Foo=z"
if e, a := expected, build.Request.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSigner_SignHTTP_NoReplaceRequestBody(t *testing.T) {
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner()
origBody := req.Body
err := s.SignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body != origBody {
t.Errorf("expect request body to not be chagned")
}
}
func TestRequestHost(t *testing.T) {
req, _ := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
req.Host = "myhost"
query := req.URL.Query()
query.Set("X-Amz-Expires", "5")
req.URL.RawQuery = query.Encode()
ctx := &httpSigner{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Time: v4Internal.NewSigningTime(time.Now()),
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
}
build, err := ctx.Build()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !strings.Contains(build.CanonicalString, "host:"+req.Host) {
t.Errorf("canonical host header invalid")
}
}
func TestSign_buildCanonicalHeadersContentLengthPresent(t *testing.T) {
body := `{"description": "this is a test"}`
req, _ := buildRequest("dynamodb", "us-east-1", body)
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
req.Host = "myhost"
contentLength := fmt.Sprintf("%d", len([]byte(body)))
req.Header.Add("Content-Length", contentLength)
query := req.URL.Query()
query.Set("X-Amz-Expires", "5")
req.URL.RawQuery = query.Encode()
ctx := &httpSigner{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Time: v4Internal.NewSigningTime(time.Now()),
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
}
build, err := ctx.Build()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !strings.Contains(build.CanonicalString, "content-length:"+contentLength+"\n") {
t.Errorf("canonical header content-length invalid")
}
}
func TestSign_buildCanonicalHeaders(t *testing.T) {
serviceName := "mockAPI"
region := "mock-region"
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, err := http.NewRequest("POST", endpoint, nil)
if err != nil {
t.Fatalf("failed to create request, %v", err)
}
req.Header.Set("FooInnerSpace", " inner space ")
req.Header.Set("FooLeadingSpace", " leading-space")
req.Header.Add("FooMultipleSpace", "no-space")
req.Header.Add("FooMultipleSpace", "\ttab-space")
req.Header.Add("FooMultipleSpace", "trailing-space ")
req.Header.Set("FooNoSpace", "no-space")
req.Header.Set("FooTabSpace", "\ttab-space\t")
req.Header.Set("FooTrailingSpace", "trailing-space ")
req.Header.Set("FooWrappedSpace", " wrapped-space ")
ctx := &httpSigner{
ServiceName: serviceName,
Region: region,
Request: req,
Time: v4Internal.NewSigningTime(time.Date(2021, 10, 20, 12, 42, 0, 0, time.UTC)),
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
}
build, err := ctx.Build()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectCanonicalString := strings.Join([]string{
`POST`,
`/`,
``,
`fooinnerspace:inner space`,
`fooleadingspace:leading-space`,
`foomultiplespace:no-space,tab-space,trailing-space`,
`foonospace:no-space`,
`footabspace:tab-space`,
`footrailingspace:trailing-space`,
`foowrappedspace:wrapped-space`,
`host:mockAPI.mock-region.amazonaws.com`,
`x-amz-date:20211020T124200Z`,
``,
`fooinnerspace;fooleadingspace;foomultiplespace;foonospace;footabspace;footrailingspace;foowrappedspace;host;x-amz-date`,
``,
}, "\n")
if diff := cmpDiff(expectCanonicalString, build.CanonicalString); diff != "" {
t.Errorf("expect match, got\n%s", diff)
}
}
func BenchmarkPresignRequest(b *testing.B) {
signer := NewSigner()
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
query := req.URL.Query()
query.Set("X-Amz-Expires", "5")
req.URL.RawQuery = query.Encode()
for i := 0; i < b.N; i++ {
signer.PresignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := NewSigner()
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
signer.SignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
}
}
func cmpDiff(e, a interface{}) string {
if !reflect.DeepEqual(e, a) {
return fmt.Sprintf("%v != %v", e, a)
}
return ""
}

View file

@ -30,7 +30,6 @@ type (
Box *accessbox.Box
Attributes []object.Attribute
PutTime time.Time
Address *oid.Address
}
)
@ -58,8 +57,8 @@ func NewAccessBoxCache(config *Config) *AccessBoxCache {
}
// Get returns a cached accessbox.
func (o *AccessBoxCache) Get(accessKeyID string) *AccessBoxCacheValue {
entry, err := o.cache.Get(accessKeyID)
func (o *AccessBoxCache) Get(address oid.Address) *AccessBoxCacheValue {
entry, err := o.cache.Get(address)
if err != nil {
return nil
}
@ -75,11 +74,16 @@ func (o *AccessBoxCache) Get(accessKeyID string) *AccessBoxCacheValue {
}
// Put stores an accessbox to cache.
func (o *AccessBoxCache) Put(accessKeyID string, val *AccessBoxCacheValue) error {
return o.cache.Set(accessKeyID, val)
func (o *AccessBoxCache) Put(address oid.Address, box *accessbox.Box, attrs []object.Attribute) error {
val := &AccessBoxCacheValue{
Box: box,
Attributes: attrs,
PutTime: time.Now(),
}
return o.cache.Set(address, val)
}
// Delete removes an accessbox from cache.
func (o *AccessBoxCache) Delete(accessKeyID string) {
o.cache.Remove(accessKeyID)
func (o *AccessBoxCache) Delete(address oid.Address) {
o.cache.Remove(address)
}

View file

@ -1,14 +1,13 @@
package cache
import (
"strings"
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-contract/frostfsid/client"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/util"
@ -23,21 +22,18 @@ func TestAccessBoxCacheType(t *testing.T) {
addr := oidtest.Address()
box := &accessbox.Box{}
val := &AccessBoxCacheValue{
Box: box,
}
var attrs []object.Attribute
accessKeyID := getAccessKeyID(addr)
err := cache.Put(accessKeyID, val)
err := cache.Put(addr, box, attrs)
require.NoError(t, err)
resVal := cache.Get(accessKeyID)
require.Equal(t, box, resVal.Box)
val := cache.Get(addr)
require.Equal(t, box, val.Box)
require.Equal(t, attrs, val.Attributes)
require.Equal(t, 0, observedLog.Len())
err = cache.cache.Set(accessKeyID, "tmp")
err = cache.cache.Set(addr, "tmp")
require.NoError(t, err)
assertInvalidCacheEntry(t, cache.Get(accessKeyID), observedLog)
assertInvalidCacheEntry(t, cache.Get(addr), observedLog)
}
func TestBucketsCacheType(t *testing.T) {
@ -234,7 +230,3 @@ func getObservedLogger() (*zap.Logger, *observer.ObservedLogs) {
loggerCore, observedLog := observer.New(zap.WarnLevel)
return zap.New(loggerCore), observedLog
}
func getAccessKeyID(addr oid.Address) string {
return strings.ReplaceAll(addr.EncodeToString(), "/", "0")
}

View file

@ -1,65 +0,0 @@
package cache
import (
"fmt"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
"github.com/bluele/gcache"
"go.uber.org/zap"
)
type (
// NetworkInfoCache provides cache for network info.
NetworkInfoCache struct {
cache gcache.Cache
logger *zap.Logger
}
// NetworkInfoCacheConfig stores expiration params for cache.
NetworkInfoCacheConfig struct {
Lifetime time.Duration
Logger *zap.Logger
}
)
const (
DefaultNetworkInfoCacheLifetime = 1 * time.Minute
networkInfoCacheSize = 1
networkInfoKey = "network_info"
)
// DefaultNetworkInfoConfig returns new default cache expiration values.
func DefaultNetworkInfoConfig(logger *zap.Logger) *NetworkInfoCacheConfig {
return &NetworkInfoCacheConfig{
Lifetime: DefaultNetworkInfoCacheLifetime,
Logger: logger,
}
}
// NewNetworkInfoCache creates an object of NetworkInfoCache.
func NewNetworkInfoCache(config *NetworkInfoCacheConfig) *NetworkInfoCache {
gc := gcache.New(networkInfoCacheSize).LRU().Expiration(config.Lifetime).Build()
return &NetworkInfoCache{cache: gc, logger: config.Logger}
}
func (c *NetworkInfoCache) Get() *netmap.NetworkInfo {
entry, err := c.cache.Get(networkInfoKey)
if err != nil {
return nil
}
result, ok := entry.(netmap.NetworkInfo)
if !ok {
c.logger.Warn(logs.InvalidCacheEntryType, zap.String("actual", fmt.Sprintf("%T", entry)),
zap.String("expected", fmt.Sprintf("%T", result)))
return nil
}
return &result
}
func (c *NetworkInfoCache) Put(info netmap.NetworkInfo) error {
return c.cache.Set(networkInfoKey, info)
}

View file

@ -29,7 +29,6 @@ type (
LifecycleExpiration struct {
Date string `xml:"Date,omitempty"`
Days *int `xml:"Days,omitempty"`
Epoch *uint64 `xml:"Epoch,omitempty"`
ExpiredObjectDeleteMarker *bool `xml:"ExpiredObjectDeleteMarker,omitempty"`
}

View file

@ -1,13 +1,10 @@
package errors
import (
"errors"
"fmt"
"net/http"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/frostfs"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/tree"
frosterr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
frosterrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
)
type (
@ -1768,7 +1765,7 @@ var errorCodes = errorCodeMap{
// IsS3Error checks if the provided error is a specific s3 error.
func IsS3Error(err error, code ErrorCode) bool {
err = frosterr.UnwrapErr(err)
err = frosterrors.UnwrapErr(err)
e, ok := err.(Error)
return ok && e.ErrCode == code
}
@ -1805,26 +1802,6 @@ func GetAPIErrorWithError(code ErrorCode, err error) Error {
return errorCodes.toAPIErrWithErr(code, err)
}
// TransformToS3Error converts FrostFS error to the corresponding S3 error type.
func TransformToS3Error(err error) error {
err = frosterr.UnwrapErr(err) // this wouldn't work with errors.Join
var s3err Error
if errors.As(err, &s3err) {
return err
}
if errors.Is(err, frostfs.ErrAccessDenied) ||
errors.Is(err, tree.ErrNodeAccessDenied) {
return GetAPIError(ErrAccessDenied)
}
if errors.Is(err, frostfs.ErrGatewayTimeout) {
return GetAPIError(ErrGatewayTimeout)
}
return GetAPIError(ErrInternalError)
}
// ObjectError -- error that is linked to a specific object.
type ObjectError struct {
Err error

View file

@ -2,12 +2,7 @@ package errors
import (
"errors"
"fmt"
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/frostfs"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/tree"
"github.com/stretchr/testify/require"
)
func BenchmarkErrCode(b *testing.B) {
@ -29,56 +24,3 @@ func BenchmarkErrorsIs(b *testing.B) {
}
}
}
func TestTransformS3Errors(t *testing.T) {
for _, tc := range []struct {
name string
err error
expected ErrorCode
}{
{
name: "simple std error to internal error",
err: errors.New("some error"),
expected: ErrInternalError,
},
{
name: "layer access denied error to s3 access denied error",
err: frostfs.ErrAccessDenied,
expected: ErrAccessDenied,
},
{
name: "wrapped layer access denied error to s3 access denied error",
err: fmt.Errorf("wrap: %w", frostfs.ErrAccessDenied),
expected: ErrAccessDenied,
},
{
name: "layer node access denied error to s3 access denied error",
err: tree.ErrNodeAccessDenied,
expected: ErrAccessDenied,
},
{
name: "layer gateway timeout error to s3 gateway timeout error",
err: frostfs.ErrGatewayTimeout,
expected: ErrGatewayTimeout,
},
{
name: "s3 error to s3 error",
err: GetAPIError(ErrInvalidPart),
expected: ErrInvalidPart,
},
{
name: "wrapped s3 error to s3 error",
err: fmt.Errorf("wrap: %w", GetAPIError(ErrInvalidPart)),
expected: ErrInvalidPart,
},
} {
t.Run(tc.name, func(t *testing.T) {
err := TransformToS3Error(tc.err)
s3err, ok := err.(Error)
require.True(t, ok, "error must be s3 error")
require.Equalf(t, tc.expected, s3err.ErrCode,
"expected: '%s', got: '%s'",
GetAPIError(tc.expected).Code, GetAPIError(s3err.ErrCode).Code)
})
}
}

View file

@ -52,18 +52,18 @@ func (h *handler) GetBucketACLHandler(w http.ResponseWriter, r *http.Request) {
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, h.encodeBucketCannedACL(ctx, bktInfo, settings)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
}
@ -127,18 +127,17 @@ func getTokenIssuerKey(box *accessbox.Box) (*keys.PublicKey, error) {
}
func (h *handler) PutBucketACLHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
@ -150,30 +149,30 @@ func (h *handler) putBucketACLAPEHandler(w http.ResponseWriter, r *http.Request,
defer func() {
if errBody := r.Body.Close(); errBody != nil {
h.reqLogger(ctx).Warn(logs.CouldNotCloseRequestBody, zap.Error(errBody))
h.reqLogger(r.Context()).Warn(logs.CouldNotCloseRequestBody, zap.Error(errBody))
}
}()
written, err := io.Copy(io.Discard, r.Body)
if err != nil {
h.logAndSendError(ctx, w, "couldn't read request body", reqInfo, err)
h.logAndSendError(w, "couldn't read request body", reqInfo, err)
return
}
if written != 0 || len(r.Header.Get(api.AmzACL)) == 0 {
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
return
}
cannedACL, err := parseCannedACL(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse canned ACL", reqInfo, err)
h.logAndSendError(w, "could not parse canned ACL", reqInfo, err)
return
}
chainRules := bucketCannedACLToAPERules(cannedACL, reqInfo, bktInfo.CID)
if err = h.ape.SaveACLChains(bktInfo.CID.EncodeToString(), chainRules); err != nil {
h.logAndSendError(ctx, w, "failed to add morph rule chains", reqInfo, err)
h.logAndSendError(w, "failed to add morph rule chains", reqInfo, err)
return
}
@ -185,7 +184,7 @@ func (h *handler) putBucketACLAPEHandler(w http.ResponseWriter, r *http.Request,
}
if err = h.obj.PutBucketSettings(ctx, sp); err != nil {
h.logAndSendError(ctx, w, "couldn't save bucket settings", reqInfo, err,
h.logAndSendError(w, "couldn't save bucket settings", reqInfo, err,
zap.String("container_id", bktInfo.CID.EncodeToString()))
return
}
@ -199,18 +198,18 @@ func (h *handler) GetObjectACLHandler(w http.ResponseWriter, r *http.Request) {
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, h.encodePrivateCannedACL(ctx, bktInfo, settings)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
}
@ -220,20 +219,19 @@ func (h *handler) PutObjectACLHandler(w http.ResponseWriter, r *http.Request) {
reqInfo := middleware.GetReqInfo(ctx)
if _, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName); err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
}
func (h *handler) GetBucketPolicyStatusHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -242,13 +240,13 @@ func (h *handler) GetBucketPolicyStatusHandler(w http.ResponseWriter, r *http.Re
if strings.Contains(err.Error(), "not found") {
err = fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrNoSuchBucketPolicy), err.Error())
}
h.logAndSendError(ctx, w, "failed to get policy from storage", reqInfo, err)
h.logAndSendError(w, "failed to get policy from storage", reqInfo, err)
return
}
var bktPolicy engineiam.Policy
if err = json.Unmarshal(jsonPolicy, &bktPolicy); err != nil {
h.logAndSendError(ctx, w, "could not parse bucket policy", reqInfo, err)
h.logAndSendError(w, "could not parse bucket policy", reqInfo, err)
return
}
@ -265,18 +263,17 @@ func (h *handler) GetBucketPolicyStatusHandler(w http.ResponseWriter, r *http.Re
}
if err = middleware.EncodeToResponse(w, policyStatus); err != nil {
h.logAndSendError(ctx, w, "encode and write response", reqInfo, err)
h.logAndSendError(w, "encode and write response", reqInfo, err)
return
}
}
func (h *handler) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -285,7 +282,7 @@ func (h *handler) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request)
if strings.Contains(err.Error(), "not found") {
err = fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrNoSuchBucketPolicy), err.Error())
}
h.logAndSendError(ctx, w, "failed to get policy from storage", reqInfo, err)
h.logAndSendError(w, "failed to get policy from storage", reqInfo, err)
return
}
@ -293,23 +290,22 @@ func (h *handler) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request)
w.WriteHeader(http.StatusOK)
if _, err = w.Write(jsonPolicy); err != nil {
h.logAndSendError(ctx, w, "write json policy to client", reqInfo, err)
h.logAndSendError(w, "write json policy to client", reqInfo, err)
}
}
func (h *handler) DeleteBucketPolicyHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
chainIDs := []chain.ID{getBucketChainID(chain.S3, bktInfo), getBucketChainID(chain.Ingress, bktInfo)}
if err = h.ape.DeleteBucketPolicy(reqInfo.Namespace, bktInfo.CID, chainIDs); err != nil {
h.logAndSendError(ctx, w, "failed to delete policy from storage", reqInfo, err)
h.logAndSendError(w, "failed to delete policy from storage", reqInfo, err)
return
}
}
@ -328,41 +324,40 @@ func checkOwner(info *data.BucketInfo, owner string) error {
}
func (h *handler) PutBucketPolicyHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
jsonPolicy, err := io.ReadAll(r.Body)
if err != nil {
h.logAndSendError(ctx, w, "read body", reqInfo, err)
h.logAndSendError(w, "read body", reqInfo, err)
return
}
var bktPolicy engineiam.Policy
if err = json.Unmarshal(jsonPolicy, &bktPolicy); err != nil {
h.logAndSendError(ctx, w, "could not parse bucket policy", reqInfo, err)
h.logAndSendError(w, "could not parse bucket policy", reqInfo, err)
return
}
for _, stat := range bktPolicy.Statement {
if len(stat.NotResource) != 0 {
h.logAndSendError(ctx, w, "policy resource mismatched bucket", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicy))
h.logAndSendError(w, "policy resource mismatched bucket", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicy))
return
}
if len(stat.NotPrincipal) != 0 && stat.Effect == engineiam.AllowEffect {
h.logAndSendError(ctx, w, "invalid NotPrincipal", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicyNotPrincipal))
h.logAndSendError(w, "invalid NotPrincipal", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicyNotPrincipal))
return
}
for _, resource := range stat.Resource {
if reqInfo.BucketName != strings.Split(strings.TrimPrefix(resource, arnAwsPrefix), "/")[0] {
h.logAndSendError(ctx, w, "policy resource mismatched bucket", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicy))
h.logAndSendError(w, "policy resource mismatched bucket", reqInfo, errors.GetAPIError(errors.ErrMalformedPolicy))
return
}
}
@ -370,7 +365,7 @@ func (h *handler) PutBucketPolicyHandler(w http.ResponseWriter, r *http.Request)
s3Chain, err := engineiam.ConvertToS3Chain(bktPolicy, h.frostfsid)
if err != nil {
h.logAndSendError(ctx, w, "could not convert s3 policy to chain policy", reqInfo, err)
h.logAndSendError(w, "could not convert s3 policy to chain policy", reqInfo, err)
return
}
s3Chain.ID = getBucketChainID(chain.S3, bktInfo)
@ -379,10 +374,10 @@ func (h *handler) PutBucketPolicyHandler(w http.ResponseWriter, r *http.Request)
if err == nil {
nativeChain.ID = getBucketChainID(chain.Ingress, bktInfo)
} else if !stderrors.Is(err, engineiam.ErrActionsNotApplicable) {
h.logAndSendError(ctx, w, "could not convert s3 policy to native chain policy", reqInfo, err)
h.logAndSendError(w, "could not convert s3 policy to native chain policy", reqInfo, err)
return
} else {
h.reqLogger(ctx).Warn(logs.PolicyCouldntBeConvertedToNativeRules)
h.reqLogger(r.Context()).Warn(logs.PolicyCouldntBeConvertedToNativeRules)
}
chainsToSave := []*chain.Chain{s3Chain}
@ -391,7 +386,7 @@ func (h *handler) PutBucketPolicyHandler(w http.ResponseWriter, r *http.Request)
}
if err = h.ape.PutBucketPolicy(reqInfo.Namespace, bktInfo.CID, jsonPolicy, chainsToSave); err != nil {
h.logAndSendError(ctx, w, "failed to update policy in contract", reqInfo, err)
h.logAndSendError(w, "failed to update policy in contract", reqInfo, err)
return
}
}

View file

@ -11,7 +11,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
@ -28,12 +28,12 @@ func TestPutObjectACLErrorAPE(t *testing.T) {
info := createBucket(hc, bktName)
putObjectWithHeadersAssertS3Error(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPublic}, apierr.ErrAccessControlListNotSupported)
putObjectWithHeadersAssertS3Error(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPublic}, s3errors.ErrAccessControlListNotSupported)
putObjectWithHeaders(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPrivate}) // only `private` canned acl is allowed, that is actually ignored
putObjectWithHeaders(hc, bktName, objName, nil)
aclBody := &AccessControlPolicy{}
putObjectACLAssertS3Error(hc, bktName, objName, info.Box, nil, aclBody, apierr.ErrAccessControlListNotSupported)
putObjectACLAssertS3Error(hc, bktName, objName, info.Box, nil, aclBody, s3errors.ErrAccessControlListNotSupported)
aclRes := getObjectACL(hc, bktName, objName)
checkPrivateACL(t, aclRes, info.Key.PublicKey())
@ -49,7 +49,7 @@ func TestCreateObjectACLErrorAPE(t *testing.T) {
copyObject(hc, bktName, objName, objNameCopy, CopyMeta{Headers: map[string]string{api.AmzACL: basicACLPublic}}, http.StatusBadRequest)
copyObject(hc, bktName, objName, objNameCopy, CopyMeta{Headers: map[string]string{api.AmzACL: basicACLPrivate}}, http.StatusOK)
createMultipartUploadAssertS3Error(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPublic}, apierr.ErrAccessControlListNotSupported)
createMultipartUploadAssertS3Error(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPublic}, s3errors.ErrAccessControlListNotSupported)
createMultipartUpload(hc, bktName, objName, map[string]string{api.AmzACL: basicACLPrivate})
}
@ -60,7 +60,7 @@ func TestBucketACLAPE(t *testing.T) {
info := createBucket(hc, bktName)
aclBody := &AccessControlPolicy{}
putBucketACLAssertS3Error(hc, bktName, info.Box, nil, aclBody, apierr.ErrAccessControlListNotSupported)
putBucketACLAssertS3Error(hc, bktName, info.Box, nil, aclBody, s3errors.ErrAccessControlListNotSupported)
aclRes := getBucketACL(hc, bktName)
checkPrivateACL(t, aclRes, info.Key.PublicKey())
@ -113,7 +113,7 @@ func TestBucketPolicy(t *testing.T) {
createTestBucket(hc, bktName)
getBucketPolicy(hc, bktName, apierr.ErrNoSuchBucketPolicy)
getBucketPolicy(hc, bktName, s3errors.ErrNoSuchBucketPolicy)
newPolicy := engineiam.Policy{
Version: "2012-10-17",
@ -125,7 +125,7 @@ func TestBucketPolicy(t *testing.T) {
}},
}
putBucketPolicy(hc, bktName, newPolicy, apierr.ErrMalformedPolicy)
putBucketPolicy(hc, bktName, newPolicy, s3errors.ErrMalformedPolicy)
newPolicy.Statement[0].Resource[0] = arnAwsPrefix + bktName + "/*"
putBucketPolicy(hc, bktName, newPolicy)
@ -140,7 +140,7 @@ func TestBucketPolicyStatus(t *testing.T) {
createTestBucket(hc, bktName)
getBucketPolicy(hc, bktName, apierr.ErrNoSuchBucketPolicy)
getBucketPolicy(hc, bktName, s3errors.ErrNoSuchBucketPolicy)
newPolicy := engineiam.Policy{
Version: "2012-10-17",
@ -152,7 +152,7 @@ func TestBucketPolicyStatus(t *testing.T) {
}},
}
putBucketPolicy(hc, bktName, newPolicy, apierr.ErrMalformedPolicyNotPrincipal)
putBucketPolicy(hc, bktName, newPolicy, s3errors.ErrMalformedPolicyNotPrincipal)
newPolicy.Statement[0].NotPrincipal = nil
newPolicy.Statement[0].Principal = map[engineiam.PrincipalType][]string{engineiam.Wildcard: {}}
@ -221,7 +221,7 @@ func TestPutBucketPolicy(t *testing.T) {
assertStatus(hc.t, w, http.StatusOK)
}
func getBucketPolicy(hc *handlerContext, bktName string, errCode ...apierr.ErrorCode) engineiam.Policy {
func getBucketPolicy(hc *handlerContext, bktName string, errCode ...s3errors.ErrorCode) engineiam.Policy {
w, r := prepareTestRequest(hc, bktName, "", nil)
hc.Handler().GetBucketPolicyHandler(w, r)
@ -231,13 +231,13 @@ func getBucketPolicy(hc *handlerContext, bktName string, errCode ...apierr.Error
err := json.NewDecoder(w.Result().Body).Decode(&policy)
require.NoError(hc.t, err)
} else {
assertS3Error(hc.t, w, apierr.GetAPIError(errCode[0]))
assertS3Error(hc.t, w, s3errors.GetAPIError(errCode[0]))
}
return policy
}
func getBucketPolicyStatus(hc *handlerContext, bktName string, errCode ...apierr.ErrorCode) PolicyStatus {
func getBucketPolicyStatus(hc *handlerContext, bktName string, errCode ...s3errors.ErrorCode) PolicyStatus {
w, r := prepareTestRequest(hc, bktName, "", nil)
hc.Handler().GetBucketPolicyStatusHandler(w, r)
@ -247,13 +247,13 @@ func getBucketPolicyStatus(hc *handlerContext, bktName string, errCode ...apierr
err := xml.NewDecoder(w.Result().Body).Decode(&policyStatus)
require.NoError(hc.t, err)
} else {
assertS3Error(hc.t, w, apierr.GetAPIError(errCode[0]))
assertS3Error(hc.t, w, s3errors.GetAPIError(errCode[0]))
}
return policyStatus
}
func putBucketPolicy(hc *handlerContext, bktName string, bktPolicy engineiam.Policy, errCode ...apierr.ErrorCode) {
func putBucketPolicy(hc *handlerContext, bktName string, bktPolicy engineiam.Policy, errCode ...s3errors.ErrorCode) {
body, err := json.Marshal(bktPolicy)
require.NoError(hc.t, err)
@ -263,7 +263,7 @@ func putBucketPolicy(hc *handlerContext, bktName string, bktPolicy engineiam.Pol
if len(errCode) == 0 {
assertStatus(hc.t, w, http.StatusOK)
} else {
assertS3Error(hc.t, w, apierr.GetAPIError(errCode[0]))
assertS3Error(hc.t, w, s3errors.GetAPIError(errCode[0]))
}
}
@ -312,9 +312,9 @@ func createBucket(hc *handlerContext, bktName string) *createBucketInfo {
}
}
func createBucketAssertS3Error(hc *handlerContext, bktName string, box *accessbox.Box, code apierr.ErrorCode) {
func createBucketAssertS3Error(hc *handlerContext, bktName string, box *accessbox.Box, code s3errors.ErrorCode) {
w := createBucketBase(hc, bktName, box)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, s3errors.GetAPIError(code))
}
func createBucketBase(hc *handlerContext, bktName string, box *accessbox.Box) *httptest.ResponseRecorder {
@ -330,9 +330,9 @@ func putBucketACL(hc *handlerContext, bktName string, box *accessbox.Box, header
assertStatus(hc.t, w, http.StatusOK)
}
func putBucketACLAssertS3Error(hc *handlerContext, bktName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy, code apierr.ErrorCode) {
func putBucketACLAssertS3Error(hc *handlerContext, bktName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy, code s3errors.ErrorCode) {
w := putBucketACLBase(hc, bktName, box, header, body)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, s3errors.GetAPIError(code))
}
func putBucketACLBase(hc *handlerContext, bktName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy) *httptest.ResponseRecorder {
@ -360,9 +360,9 @@ func getBucketACLBase(hc *handlerContext, bktName string) *httptest.ResponseReco
return w
}
func putObjectACLAssertS3Error(hc *handlerContext, bktName, objName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy, code apierr.ErrorCode) {
func putObjectACLAssertS3Error(hc *handlerContext, bktName, objName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy, code s3errors.ErrorCode) {
w := putObjectACLBase(hc, bktName, objName, box, header, body)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, s3errors.GetAPIError(code))
}
func putObjectACLBase(hc *handlerContext, bktName, objName string, box *accessbox.Box, header map[string]string, body *AccessControlPolicy) *httptest.ResponseRecorder {
@ -396,9 +396,9 @@ func putObjectWithHeaders(hc *handlerContext, bktName, objName string, headers m
return w.Header()
}
func putObjectWithHeadersAssertS3Error(hc *handlerContext, bktName, objName string, headers map[string]string, code apierr.ErrorCode) {
func putObjectWithHeadersAssertS3Error(hc *handlerContext, bktName, objName string, headers map[string]string, code s3errors.ErrorCode) {
w := putObjectWithHeadersBase(hc, bktName, objName, headers, nil, nil)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, s3errors.GetAPIError(code))
}
func putObjectWithHeadersBase(hc *handlerContext, bktName, objName string, headers map[string]string, box *accessbox.Box, data []byte) *httptest.ResponseRecorder {

View file

@ -70,18 +70,17 @@ var validAttributes = map[string]struct{}{
}
func (h *handler) GetObjectAttributesHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
params, err := parseGetObjectAttributeArgs(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid request", reqInfo, err)
h.logAndSendError(w, "invalid request", reqInfo, err)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -91,44 +90,44 @@ func (h *handler) GetObjectAttributesHandler(w http.ResponseWriter, r *http.Requ
VersionID: params.VersionID,
}
extendedInfo, err := h.obj.GetExtendedObjectInfo(ctx, p)
extendedInfo, err := h.obj.GetExtendedObjectInfo(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not fetch object info", reqInfo, err)
h.logAndSendError(w, "could not fetch object info", reqInfo, err)
return
}
info := extendedInfo.ObjectInfo
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = encryptionParams.MatchObjectEncryption(layer.FormEncryptionInfo(info.Headers)); err != nil {
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
h.logAndSendError(w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
return
}
if err = checkPreconditions(info, params.Conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, err)
h.logAndSendError(w, "precondition failed", reqInfo, err)
return
}
bktSettings, err := h.obj.GetBucketSettings(ctx, bktInfo)
bktSettings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
response, err := encodeToObjectAttributesResponse(info, params, h.cfg.MD5Enabled())
if err != nil {
h.logAndSendError(ctx, w, "couldn't encode object info to response", reqInfo, err)
h.logAndSendError(w, "couldn't encode object info to response", reqInfo, err)
return
}
writeAttributesHeaders(w.Header(), extendedInfo, bktSettings.Unversioned())
if err = middleware.EncodeToResponse(w, response); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}

View file

@ -65,7 +65,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
srcBucket, srcObject, err := path2BucketObject(src)
if err != nil {
h.logAndSendError(ctx, w, "invalid source copy", reqInfo, err)
h.logAndSendError(w, "invalid source copy", reqInfo, err)
return
}
@ -75,74 +75,74 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
}
if srcObjPrm.BktInfo, err = h.getBucketAndCheckOwner(r, srcBucket, api.AmzSourceExpectedBucketOwner); err != nil {
h.logAndSendError(ctx, w, "couldn't get source bucket", reqInfo, err)
h.logAndSendError(w, "couldn't get source bucket", reqInfo, err)
return
}
dstBktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get target bucket", reqInfo, err)
h.logAndSendError(w, "couldn't get target bucket", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, dstBktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
if cannedACLStatus == aclStatusYes {
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
return
}
extendedSrcObjInfo, err := h.obj.GetExtendedObjectInfo(ctx, srcObjPrm)
if err != nil {
h.logAndSendError(ctx, w, "could not find object", reqInfo, err)
h.logAndSendError(w, "could not find object", reqInfo, err)
return
}
srcObjInfo := extendedSrcObjInfo.ObjectInfo
srcEncryptionParams, err := formCopySourceEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
dstEncryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = srcEncryptionParams.MatchObjectEncryption(layer.FormEncryptionInfo(srcObjInfo.Headers)); err != nil {
if errors.IsS3Error(err, errors.ErrInvalidEncryptionParameters) || errors.IsS3Error(err, errors.ErrSSEEncryptedObject) ||
errors.IsS3Error(err, errors.ErrInvalidSSECustomerParameters) {
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, err, zap.Error(err))
h.logAndSendError(w, "encryption doesn't match object", reqInfo, err, zap.Error(err))
return
}
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
h.logAndSendError(w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
return
}
var dstSize uint64
srcSize, err := layer.GetObjectSize(srcObjInfo)
if err != nil {
h.logAndSendError(ctx, w, "failed to get source object size", reqInfo, err)
h.logAndSendError(w, "failed to get source object size", reqInfo, err)
return
} else if srcSize > layer.UploadMaxSize { // https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html
h.logAndSendError(ctx, w, "too bid object to copy with single copy operation, use multipart upload copy instead", reqInfo, errors.GetAPIError(errors.ErrInvalidRequestLargeCopy))
h.logAndSendError(w, "too bid object to copy with single copy operation, use multipart upload copy instead", reqInfo, errors.GetAPIError(errors.ErrInvalidRequestLargeCopy))
return
}
dstSize = srcSize
args, err := parseCopyObjectArgs(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse request params", reqInfo, err)
h.logAndSendError(w, "could not parse request params", reqInfo, err)
return
}
if isCopyingToItselfForbidden(reqInfo, srcBucket, srcObject, settings, args) {
h.logAndSendError(ctx, w, "copying to itself without changing anything", reqInfo, errors.GetAPIError(errors.ErrInvalidCopyDest))
h.logAndSendError(w, "copying to itself without changing anything", reqInfo, errors.GetAPIError(errors.ErrInvalidCopyDest))
return
}
@ -153,7 +153,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
if args.TaggingDirective == replaceDirective {
tagSet, err = parseTaggingHeader(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse tagging header", reqInfo, err)
h.logAndSendError(w, "could not parse tagging header", reqInfo, err)
return
}
} else {
@ -168,13 +168,13 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
_, tagSet, err = h.obj.GetObjectTagging(ctx, tagPrm)
if err != nil {
h.logAndSendError(ctx, w, "could not get object tagging", reqInfo, err)
h.logAndSendError(w, "could not get object tagging", reqInfo, err)
return
}
}
if err = checkPreconditions(srcObjInfo, args.Conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, errors.GetAPIError(errors.ErrPreconditionFailed))
h.logAndSendError(w, "precondition failed", reqInfo, errors.GetAPIError(errors.ErrPreconditionFailed))
return
}
@ -202,20 +202,20 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
params.CopiesNumbers, err = h.pickCopiesNumbers(metadata, reqInfo.Namespace, dstBktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
params.Lock, err = formObjectLock(ctx, dstBktInfo, settings.LockConfiguration, r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not form object lock", reqInfo, err)
h.logAndSendError(w, "could not form object lock", reqInfo, err)
return
}
additional := []zap.Field{zap.String("src_bucket_name", srcBucket), zap.String("src_object_name", srcObject)}
extendedDstObjInfo, err := h.obj.CopyObject(ctx, params)
if err != nil {
h.logAndSendError(ctx, w, "couldn't copy object", reqInfo, err, additional...)
h.logAndSendError(w, "couldn't copy object", reqInfo, err, additional...)
return
}
dstObjInfo := extendedDstObjInfo.ObjectInfo
@ -224,7 +224,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
LastModified: dstObjInfo.Created.UTC().Format(time.RFC3339),
ETag: data.Quote(dstObjInfo.ETag(h.cfg.MD5Enabled())),
}); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err, additional...)
h.logAndSendError(w, "something went wrong", reqInfo, err, additional...)
return
}
@ -239,7 +239,7 @@ func (h *handler) CopyObjectHandler(w http.ResponseWriter, r *http.Request) {
NodeVersion: extendedDstObjInfo.NodeVersion,
}
if err = h.obj.PutObjectTagging(ctx, tagPrm); err != nil {
h.logAndSendError(ctx, w, "could not upload object tagging", reqInfo, err)
h.logAndSendError(w, "could not upload object tagging", reqInfo, err)
return
}
}

View file

@ -20,34 +20,32 @@ const (
)
func (h *handler) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
cors, err := h.obj.GetBucketCORS(ctx, bktInfo)
cors, err := h.obj.GetBucketCORS(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get cors", reqInfo, err)
h.logAndSendError(w, "could not get cors", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, cors); err != nil {
h.logAndSendError(ctx, w, "could not encode cors to response", reqInfo, err)
h.logAndSendError(w, "could not encode cors to response", reqInfo, err)
return
}
}
func (h *handler) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -59,33 +57,32 @@ func (h *handler) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
p.CopiesNumbers, err = h.pickCopiesNumbers(parseMetadata(r), reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
if err = h.obj.PutBucketCORS(ctx, p); err != nil {
h.logAndSendError(ctx, w, "could not put cors configuration", reqInfo, err)
if err = h.obj.PutBucketCORS(r.Context(), p); err != nil {
h.logAndSendError(w, "could not put cors configuration", reqInfo, err)
return
}
if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
}
func (h *handler) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if err = h.obj.DeleteBucketCORS(ctx, bktInfo); err != nil {
h.logAndSendError(ctx, w, "could not delete cors", reqInfo, err)
if err = h.obj.DeleteBucketCORS(r.Context(), bktInfo); err != nil {
h.logAndSendError(w, "could not delete cors", reqInfo, err)
}
w.WriteHeader(http.StatusNoContent)
@ -152,22 +149,21 @@ func (h *handler) AppendCORSHeaders(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
bktInfo, err := h.obj.GetBucketInfo(ctx, reqInfo.BucketName)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.obj.GetBucketInfo(r.Context(), reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
origin := r.Header.Get(api.Origin)
if origin == "" {
h.logAndSendError(ctx, w, "origin request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
h.logAndSendError(w, "origin request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
}
method := r.Header.Get(api.AccessControlRequestMethod)
if method == "" {
h.logAndSendError(ctx, w, "Access-Control-Request-Method request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
h.logAndSendError(w, "Access-Control-Request-Method request header needed", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
return
}
@ -177,9 +173,9 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
headers = strings.Split(requestHeaders, ", ")
}
cors, err := h.obj.GetBucketCORS(ctx, bktInfo)
cors, err := h.obj.GetBucketCORS(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get cors", reqInfo, err)
h.logAndSendError(w, "could not get cors", reqInfo, err)
return
}
@ -208,7 +204,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
w.Header().Set(api.AccessControlAllowCredentials, "true")
}
if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
return
@ -217,7 +213,7 @@ func (h *handler) Preflight(w http.ResponseWriter, r *http.Request) {
}
}
}
h.logAndSendError(ctx, w, "Forbidden", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
h.logAndSendError(w, "Forbidden", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
}
func checkSubslice(slice []string, subSlice []string) bool {

View file

@ -71,19 +71,19 @@ func (h *handler) DeleteObjectHandler(w http.ResponseWriter, r *http.Request) {
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
bktSettings, err := h.obj.GetBucketSettings(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
networkInfo, err := h.obj.GetNetworkInfo(ctx)
if err != nil {
h.logAndSendError(ctx, w, "could not get network info", reqInfo, err)
h.logAndSendError(w, "could not get network info", reqInfo, err)
return
}
@ -97,9 +97,9 @@ func (h *handler) DeleteObjectHandler(w http.ResponseWriter, r *http.Request) {
deletedObject := deletedObjects[0]
if deletedObject.Error != nil {
if isErrObjectLocked(deletedObject.Error) {
h.logAndSendError(ctx, w, "object is locked", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
h.logAndSendError(w, "object is locked", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
} else {
h.logAndSendError(ctx, w, "could not delete object", reqInfo, deletedObject.Error)
h.logAndSendError(w, "could not delete object", reqInfo, deletedObject.Error)
}
return
}
@ -134,26 +134,26 @@ func (h *handler) DeleteMultipleObjectsHandler(w http.ResponseWriter, r *http.Re
// Content-Md5 is required and should be set
// http://docs.aws.amazon.com/AmazonS3/latest/API/multiobjectdeleteapi.html
if _, ok := r.Header[api.ContentMD5]; !ok {
h.logAndSendError(ctx, w, "missing Content-MD5", reqInfo, errors.GetAPIError(errors.ErrMissingContentMD5))
h.logAndSendError(w, "missing Content-MD5", reqInfo, errors.GetAPIError(errors.ErrMissingContentMD5))
return
}
// Content-Length is required and should be non-zero
// http://docs.aws.amazon.com/AmazonS3/latest/API/multiobjectdeleteapi.html
if r.ContentLength <= 0 {
h.logAndSendError(ctx, w, "missing Content-Length", reqInfo, errors.GetAPIError(errors.ErrMissingContentLength))
h.logAndSendError(w, "missing Content-Length", reqInfo, errors.GetAPIError(errors.ErrMissingContentLength))
return
}
// Unmarshal list of keys to be deleted.
requested := &DeleteObjectsRequest{}
if err := h.cfg.NewXMLDecoder(r.Body).Decode(requested); err != nil {
h.logAndSendError(ctx, w, "couldn't decode body", reqInfo, fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrMalformedXML), err.Error()))
h.logAndSendError(w, "couldn't decode body", reqInfo, fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrMalformedXML), err.Error()))
return
}
if len(requested.Objects) == 0 || len(requested.Objects) > maxObjectsToDelete {
h.logAndSendError(ctx, w, "number of objects to delete must be greater than 0 and less or equal to 1000", reqInfo, errors.GetAPIError(errors.ErrMalformedXML))
h.logAndSendError(w, "number of objects to delete must be greater than 0 and less or equal to 1000", reqInfo, errors.GetAPIError(errors.ErrMalformedXML))
return
}
@ -178,19 +178,19 @@ func (h *handler) DeleteMultipleObjectsHandler(w http.ResponseWriter, r *http.Re
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
bktSettings, err := h.obj.GetBucketSettings(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
networkInfo, err := h.obj.GetNetworkInfo(ctx)
if err != nil {
h.logAndSendError(ctx, w, "could not get network info", reqInfo, err)
h.logAndSendError(w, "could not get network info", reqInfo, err)
return
}
@ -231,23 +231,22 @@ func (h *handler) DeleteMultipleObjectsHandler(w http.ResponseWriter, r *http.Re
}
if err = middleware.EncodeToResponse(w, response); err != nil {
h.logAndSendError(ctx, w, "could not write response", reqInfo, err)
h.logAndSendError(w, "could not write response", reqInfo, err)
return
}
}
func (h *handler) DeleteBucketHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
var sessionToken *session.Container
boxData, err := middleware.GetBoxData(ctx)
boxData, err := middleware.GetBoxData(r.Context())
if err == nil {
sessionToken = boxData.Gate.SessionTokenForDelete()
}
@ -260,12 +259,12 @@ func (h *handler) DeleteBucketHandler(w http.ResponseWriter, r *http.Request) {
}
}
if err = h.obj.DeleteBucket(ctx, &layer.DeleteBucketParams{
if err = h.obj.DeleteBucket(r.Context(), &layer.DeleteBucketParams{
BktInfo: bktInfo,
SessionToken: sessionToken,
SkipCheck: skipObjCheck,
}); err != nil {
h.logAndSendError(ctx, w, "couldn't delete bucket", reqInfo, err)
h.logAndSendError(w, "couldn't delete bucket", reqInfo, err)
return
}
@ -276,7 +275,7 @@ func (h *handler) DeleteBucketHandler(w http.ResponseWriter, r *http.Request) {
getBucketCannedChainID(chain.Ingress, bktInfo.CID),
}
if err = h.ape.DeleteBucketPolicy(reqInfo.Namespace, bktInfo.CID, chainIDs); err != nil {
h.logAndSendError(ctx, w, "failed to delete policy from storage", reqInfo, err)
h.logAndSendError(w, "failed to delete policy from storage", reqInfo, err)
return
}

View file

@ -3,6 +3,7 @@ package handler
import (
"bytes"
"encoding/xml"
"io"
"net/http"
"net/http/httptest"
"net/url"
@ -10,13 +11,10 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/require"
)
@ -131,20 +129,21 @@ func TestDeleteObjectsError(t *testing.T) {
addr.SetContainer(bktInfo.CID)
addr.SetObject(nodeVersion.OID)
expectedError := apierr.GetAPIError(apierr.ErrAccessDenied)
expectedError := apiErrors.GetAPIError(apiErrors.ErrAccessDenied)
hc.tp.SetObjectError(addr, expectedError)
w := deleteObjectsBase(hc, bktName, [][2]string{{objName, nodeVersion.OID.EncodeToString()}})
res := &s3.DeleteObjectsOutput{}
err = xmlutil.UnmarshalXML(res, xml.NewDecoder(w.Result().Body), "")
var buf bytes.Buffer
res := &DeleteObjectsResponse{}
err = xml.NewDecoder(io.TeeReader(w.Result().Body, &buf)).Decode(res)
require.NoError(t, err)
require.ElementsMatch(t, []*s3.Error{{
Code: aws.String(expectedError.Code),
Key: aws.String(objName),
Message: aws.String(expectedError.Error()),
VersionId: aws.String(nodeVersion.OID.EncodeToString()),
require.Contains(t, buf.String(), "VersionId")
require.ElementsMatch(t, []DeleteError{{
Code: expectedError.Code,
Key: objName,
Message: expectedError.Error(),
VersionID: nodeVersion.OID.EncodeToString(),
}}, res.Errors)
}
@ -553,9 +552,9 @@ func checkNotFound(t *testing.T, hc *handlerContext, bktName, objName, version s
assertStatus(t, w, http.StatusNotFound)
}
func headObjectAssertS3Error(hc *handlerContext, bktName, objName, version string, code apierr.ErrorCode) {
func headObjectAssertS3Error(hc *handlerContext, bktName, objName, version string, code apiErrors.ErrorCode) {
w := headObjectBase(hc, bktName, objName, version)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, apiErrors.GetAPIError(code))
}
func checkFound(t *testing.T, hc *handlerContext, bktName, objName, version string) {

View file

@ -16,7 +16,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/frostfs"
"github.com/stretchr/testify/require"
)
@ -38,7 +37,7 @@ func TestSimpleGetEncrypted(t *testing.T) {
objInfo, err := tc.Layer().GetObjectInfo(tc.Context(), &layer.HeadObjectParams{BktInfo: bktInfo, Object: objName})
require.NoError(t, err)
obj, err := tc.MockedPool().GetObject(tc.Context(), frostfs.PrmObjectGet{Container: bktInfo.CID, Object: objInfo.ID})
obj, err := tc.MockedPool().GetObject(tc.Context(), layer.PrmObjectGet{Container: bktInfo.CID, Object: objInfo.ID})
require.NoError(t, err)
encryptedContent, err := io.ReadAll(obj.Payload)
require.NoError(t, err)

View file

@ -129,19 +129,18 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
var (
params *layer.RangeParams
ctx = r.Context()
reqInfo = middleware.GetReqInfo(ctx)
reqInfo = middleware.GetReqInfo(r.Context())
)
conditional, err := parseConditionalHeaders(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse request params", reqInfo, err)
h.logAndSendError(w, "could not parse request params", reqInfo, err)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -151,37 +150,37 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
VersionID: reqInfo.URL.Query().Get(api.QueryVersionID),
}
extendedInfo, err := h.obj.GetExtendedObjectInfo(ctx, p)
extendedInfo, err := h.obj.GetExtendedObjectInfo(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not find object", reqInfo, err)
h.logAndSendError(w, "could not find object", reqInfo, err)
return
}
info := extendedInfo.ObjectInfo
if err = checkPreconditions(info, conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, err)
h.logAndSendError(w, "precondition failed", reqInfo, err)
return
}
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = encryptionParams.MatchObjectEncryption(layer.FormEncryptionInfo(info.Headers)); err != nil {
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
h.logAndSendError(w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
return
}
fullSize, err := layer.GetObjectSize(info)
if err != nil {
h.logAndSendError(ctx, w, "invalid size header", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
h.logAndSendError(w, "invalid size header", reqInfo, errors.GetAPIError(errors.ErrBadRequest))
return
}
if params, err = fetchRangeHeader(r.Header, fullSize); err != nil {
h.logAndSendError(ctx, w, "could not parse range header", reqInfo, err)
h.logAndSendError(w, "could not parse range header", reqInfo, err)
return
}
@ -191,24 +190,24 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
VersionID: info.VersionID(),
}
tagSet, lockInfo, err := h.obj.GetObjectTaggingAndLock(ctx, t, extendedInfo.NodeVersion)
tagSet, lockInfo, err := h.obj.GetObjectTaggingAndLock(r.Context(), t, extendedInfo.NodeVersion)
if err != nil && !errors.IsS3Error(err, errors.ErrNoSuchKey) {
h.logAndSendError(ctx, w, "could not get object meta data", reqInfo, err)
h.logAndSendError(w, "could not get object meta data", reqInfo, err)
return
}
if layer.IsAuthenticatedRequest(ctx) {
if layer.IsAuthenticatedRequest(r.Context()) {
overrideResponseHeaders(w.Header(), reqInfo.URL.Query())
}
if err = h.setLockingHeaders(bktInfo, lockInfo, w.Header()); err != nil {
h.logAndSendError(ctx, w, "could not get locking info", reqInfo, err)
h.logAndSendError(w, "could not get locking info", reqInfo, err)
return
}
bktSettings, err := h.obj.GetBucketSettings(ctx, bktInfo)
bktSettings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
@ -220,9 +219,9 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
Encryption: encryptionParams,
}
objPayload, err := h.obj.GetObject(ctx, getPayloadParams)
objPayload, err := h.obj.GetObject(r.Context(), getPayloadParams)
if err != nil {
h.logAndSendError(ctx, w, "could not get object payload", reqInfo, err)
h.logAndSendError(w, "could not get object payload", reqInfo, err)
return
}
@ -234,7 +233,7 @@ func (h *handler) GetObjectHandler(w http.ResponseWriter, r *http.Request) {
}
if err = objPayload.StreamTo(w); err != nil {
h.logAndSendError(ctx, w, "could not stream object payload", reqInfo, err)
h.logAndSendError(w, "could not stream object payload", reqInfo, err)
return
}
}

View file

@ -2,7 +2,7 @@ package handler
import (
"bytes"
"errors"
stderrors "errors"
"fmt"
"io"
"net/http"
@ -13,7 +13,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
"github.com/stretchr/testify/require"
@ -89,7 +89,7 @@ func TestPreconditions(t *testing.T) {
name: "IfMatch false",
info: newInfo(etag, today),
args: &conditionalArgs{IfMatch: etag2},
expected: apierr.GetAPIError(apierr.ErrPreconditionFailed)},
expected: errors.GetAPIError(errors.ErrPreconditionFailed)},
{
name: "IfNoneMatch true",
info: newInfo(etag, today),
@ -99,7 +99,7 @@ func TestPreconditions(t *testing.T) {
name: "IfNoneMatch false",
info: newInfo(etag, today),
args: &conditionalArgs{IfNoneMatch: etag},
expected: apierr.GetAPIError(apierr.ErrNotModified)},
expected: errors.GetAPIError(errors.ErrNotModified)},
{
name: "IfModifiedSince true",
info: newInfo(etag, today),
@ -109,7 +109,7 @@ func TestPreconditions(t *testing.T) {
name: "IfModifiedSince false",
info: newInfo(etag, yesterday),
args: &conditionalArgs{IfModifiedSince: &today},
expected: apierr.GetAPIError(apierr.ErrNotModified)},
expected: errors.GetAPIError(errors.ErrNotModified)},
{
name: "IfUnmodifiedSince true",
info: newInfo(etag, yesterday),
@ -119,7 +119,7 @@ func TestPreconditions(t *testing.T) {
name: "IfUnmodifiedSince false",
info: newInfo(etag, today),
args: &conditionalArgs{IfUnmodifiedSince: &yesterday},
expected: apierr.GetAPIError(apierr.ErrPreconditionFailed)},
expected: errors.GetAPIError(errors.ErrPreconditionFailed)},
{
name: "IfMatch true, IfUnmodifiedSince false",
@ -131,19 +131,19 @@ func TestPreconditions(t *testing.T) {
name: "IfMatch false, IfUnmodifiedSince true",
info: newInfo(etag, yesterday),
args: &conditionalArgs{IfMatch: etag2, IfUnmodifiedSince: &today},
expected: apierr.GetAPIError(apierr.ErrPreconditionFailed),
expected: errors.GetAPIError(errors.ErrPreconditionFailed),
},
{
name: "IfNoneMatch false, IfModifiedSince true",
info: newInfo(etag, today),
args: &conditionalArgs{IfNoneMatch: etag, IfModifiedSince: &yesterday},
expected: apierr.GetAPIError(apierr.ErrNotModified),
expected: errors.GetAPIError(errors.ErrNotModified),
},
{
name: "IfNoneMatch true, IfModifiedSince false",
info: newInfo(etag, yesterday),
args: &conditionalArgs{IfNoneMatch: etag2, IfModifiedSince: &today},
expected: apierr.GetAPIError(apierr.ErrNotModified),
expected: errors.GetAPIError(errors.ErrNotModified),
},
} {
t.Run(tc.name, func(t *testing.T) {
@ -151,7 +151,7 @@ func TestPreconditions(t *testing.T) {
if tc.expected == nil {
require.NoError(t, actual)
} else {
require.True(t, errors.Is(actual, tc.expected), tc.expected, actual)
require.True(t, stderrors.Is(actual, tc.expected), tc.expected, actual)
}
})
}
@ -193,8 +193,8 @@ func TestGetObject(t *testing.T) {
hc.tp.SetObjectError(addr, &apistatus.ObjectNotFound{})
hc.tp.SetObjectError(objInfo.Address(), &apistatus.ObjectNotFound{})
getObjectAssertS3Error(hc, bktName, objName, objInfo.VersionID(), apierr.ErrNoSuchVersion)
getObjectAssertS3Error(hc, bktName, objName, emptyVersion, apierr.ErrNoSuchKey)
getObjectAssertS3Error(hc, bktName, objName, objInfo.VersionID(), errors.ErrNoSuchVersion)
getObjectAssertS3Error(hc, bktName, objName, emptyVersion, errors.ErrNoSuchKey)
}
func TestGetObjectEnabledMD5(t *testing.T) {
@ -236,9 +236,9 @@ func getObjectVersion(tc *handlerContext, bktName, objName, version string) []by
return content
}
func getObjectAssertS3Error(hc *handlerContext, bktName, objName, version string, code apierr.ErrorCode) {
func getObjectAssertS3Error(hc *handlerContext, bktName, objName, version string, code errors.ErrorCode) {
w := getObjectBaseResponse(hc, bktName, objName, version)
assertS3Error(hc.t, w, apierr.GetAPIError(code))
assertS3Error(hc.t, w, errors.GetAPIError(code))
}
func getObjectBaseResponse(hc *handlerContext, bktName, objName, version string) *httptest.ResponseRecorder {

View file

@ -9,7 +9,6 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/xml"
"errors"
"mime/multipart"
"net/http"
"net/http/httptest"
@ -89,12 +88,6 @@ func addMD5Header(tp *utils.TypeProvider, r *http.Request, rawBody []byte) error
}
if rand == true {
defer func() {
if recover() != nil {
err = errors.New("panic in base64")
}
}()
var dst []byte
base64.StdEncoding.Encode(dst, rawBody)
hash := md5.Sum(dst)
@ -591,11 +584,6 @@ func DoFuzzCopyObjectHandler(input []byte) int {
return fuzzFailExitCode
}
defer func() {
if recover() != nil {
err = errors.New("panic in httptest.NewRequest")
}
}()
r = httptest.NewRequest(http.MethodPut, defaultURL+params, nil)
if r != nil {
return fuzzFailExitCode
@ -649,11 +637,6 @@ func DoFuzzDeleteObjectHandler(input []byte) int {
return fuzzFailExitCode
}
defer func() {
if recover() != nil {
err = errors.New("panic in httptest.NewRequest")
}
}()
r = httptest.NewRequest(http.MethodDelete, defaultURL+params, nil)
if r != nil {
return fuzzFailExitCode
@ -706,11 +689,6 @@ func DoFuzzGetObjectHandler(input []byte) int {
w := httptest.NewRecorder()
defer func() {
if recover() != nil {
err = errors.New("panic in httptest.NewRequest")
}
}()
r := httptest.NewRequest(http.MethodGet, defaultURL+params, nil)
if r != nil {
return fuzzFailExitCode
@ -936,11 +914,6 @@ func DoFuzzPutObjectRetentionHandler(input []byte) int {
w := httptest.NewRecorder()
defer func() {
if recover() != nil {
err = errors.New("panic in httptest.NewRequest")
}
}()
r := httptest.NewRequest(http.MethodPut, defaultURL+objName+"?retention", bytes.NewReader(rawBody))
if r != nil {
return fuzzFailExitCode

View file

@ -20,7 +20,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/frostfs"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/resolver"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/pkg/service/tree"
@ -244,14 +243,12 @@ func getMinCacheConfig(logger *zap.Logger) *layer.CachesConfig {
Buckets: minCacheCfg,
System: minCacheCfg,
AccessControl: minCacheCfg,
NetworkInfo: &cache.NetworkInfoCacheConfig{Lifetime: minCacheCfg.Lifetime},
}
}
type apeMock struct {
chainMap map[engine.Target][]*chain.Chain
policyMap map[string][]byte
err error
}
func newAPEMock() *apeMock {
@ -295,10 +292,6 @@ func (a *apeMock) DeletePolicy(namespace string, cnrID cid.ID) error {
}
func (a *apeMock) PutBucketPolicy(ns string, cnrID cid.ID, policy []byte, chain []*chain.Chain) error {
if a.err != nil {
return a.err
}
if err := a.PutPolicy(ns, cnrID, policy); err != nil {
return err
}
@ -313,10 +306,6 @@ func (a *apeMock) PutBucketPolicy(ns string, cnrID cid.ID, policy []byte, chain
}
func (a *apeMock) DeleteBucketPolicy(ns string, cnrID cid.ID, chainIDs []chain.ID) error {
if a.err != nil {
return a.err
}
if err := a.DeletePolicy(ns, cnrID); err != nil {
return err
}
@ -330,10 +319,6 @@ func (a *apeMock) DeleteBucketPolicy(ns string, cnrID cid.ID, chainIDs []chain.I
}
func (a *apeMock) GetBucketPolicy(ns string, cnrID cid.ID) ([]byte, error) {
if a.err != nil {
return nil, a.err
}
policy, ok := a.policyMap[ns+cnrID.EncodeToString()]
if !ok {
return nil, errors.New("not found")
@ -343,10 +328,6 @@ func (a *apeMock) GetBucketPolicy(ns string, cnrID cid.ID) ([]byte, error) {
}
func (a *apeMock) SaveACLChains(cid string, chains []*chain.Chain) error {
if a.err != nil {
return a.err
}
for i := range chains {
if err := a.AddChain(engine.ContainerTarget(cid), chains[i]); err != nil {
return err
@ -388,7 +369,7 @@ func createTestBucket(hc *handlerContext, bktName string) *data.BucketInfo {
}
func createTestBucketWithLock(hc *handlerContext, bktName string, conf *data.ObjectLockConfiguration) *data.BucketInfo {
res, err := hc.MockedPool().CreateContainer(hc.Context(), frostfs.PrmContainerCreate{
res, err := hc.MockedPool().CreateContainer(hc.Context(), layer.PrmContainerCreate{
Creator: hc.owner,
Name: bktName,
AdditionalAttributes: [][2]string{{layer.AttributeLockEnabled, "true"}},
@ -435,7 +416,7 @@ func createTestObject(hc *handlerContext, bktInfo *data.BucketInfo, objName stri
extObjInfo, err := hc.Layer().PutObject(hc.Context(), &layer.PutObjectParams{
BktInfo: bktInfo,
Object: objName,
Size: ptr(uint64(len(content))),
Size: uint64(len(content)),
Reader: bytes.NewReader(content),
Header: header,
Encryption: encryption,

View file

@ -27,18 +27,17 @@ func getRangeToDetectContentType(maxSize uint64) *layer.RangeParams {
}
func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
conditional, err := parseConditionalHeaders(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse request params", reqInfo, err)
h.logAndSendError(w, "could not parse request params", reqInfo, err)
return
}
@ -48,26 +47,26 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
VersionID: reqInfo.URL.Query().Get(api.QueryVersionID),
}
extendedInfo, err := h.obj.GetExtendedObjectInfo(ctx, p)
extendedInfo, err := h.obj.GetExtendedObjectInfo(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not find object", reqInfo, err)
h.logAndSendError(w, "could not find object", reqInfo, err)
return
}
info := extendedInfo.ObjectInfo
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = encryptionParams.MatchObjectEncryption(layer.FormEncryptionInfo(info.Headers)); err != nil {
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
h.logAndSendError(w, "encryption doesn't match object", reqInfo, errors.GetAPIError(errors.ErrBadRequest), zap.Error(err))
return
}
if err = checkPreconditions(info, conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, err)
h.logAndSendError(w, "precondition failed", reqInfo, err)
return
}
@ -77,9 +76,9 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
VersionID: info.VersionID(),
}
tagSet, lockInfo, err := h.obj.GetObjectTaggingAndLock(ctx, t, extendedInfo.NodeVersion)
tagSet, lockInfo, err := h.obj.GetObjectTaggingAndLock(r.Context(), t, extendedInfo.NodeVersion)
if err != nil && !errors.IsS3Error(err, errors.ErrNoSuchKey) {
h.logAndSendError(ctx, w, "could not get object meta data", reqInfo, err)
h.logAndSendError(w, "could not get object meta data", reqInfo, err)
return
}
@ -92,15 +91,15 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
BucketInfo: bktInfo,
}
objPayload, err := h.obj.GetObject(ctx, getParams)
objPayload, err := h.obj.GetObject(r.Context(), getParams)
if err != nil {
h.logAndSendError(ctx, w, "could not get object", reqInfo, err, zap.Stringer("oid", info.ID))
h.logAndSendError(w, "could not get object", reqInfo, err, zap.Stringer("oid", info.ID))
return
}
buffer, err := io.ReadAll(objPayload)
if err != nil {
h.logAndSendError(ctx, w, "could not partly read payload to detect content type", reqInfo, err, zap.Stringer("oid", info.ID))
h.logAndSendError(w, "could not partly read payload to detect content type", reqInfo, err, zap.Stringer("oid", info.ID))
return
}
@ -109,13 +108,13 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
}
if err = h.setLockingHeaders(bktInfo, lockInfo, w.Header()); err != nil {
h.logAndSendError(ctx, w, "could not get locking info", reqInfo, err)
h.logAndSendError(w, "could not get locking info", reqInfo, err)
return
}
bktSettings, err := h.obj.GetBucketSettings(ctx, bktInfo)
bktSettings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
@ -124,12 +123,11 @@ func (h *handler) HeadObjectHandler(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) HeadBucketHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -143,7 +141,7 @@ func (h *handler) HeadBucketHandler(w http.ResponseWriter, r *http.Request) {
}
if err = middleware.WriteResponse(w, http.StatusOK, nil, middleware.MimeNone); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
}

View file

@ -6,7 +6,7 @@ import (
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
@ -95,8 +95,8 @@ func TestHeadObject(t *testing.T) {
hc.tp.SetObjectError(addr, &apistatus.ObjectNotFound{})
hc.tp.SetObjectError(objInfo.Address(), &apistatus.ObjectNotFound{})
headObjectAssertS3Error(hc, bktName, objName, objInfo.VersionID(), apierr.ErrNoSuchVersion)
headObjectAssertS3Error(hc, bktName, objName, emptyVersion, apierr.ErrNoSuchKey)
headObjectAssertS3Error(hc, bktName, objName, objInfo.VersionID(), s3errors.ErrNoSuchVersion)
headObjectAssertS3Error(hc, bktName, objName, emptyVersion, s3errors.ErrNoSuchKey)
}
func TestIsAvailableToResolve(t *testing.T) {

View file

@ -7,16 +7,15 @@ import (
)
func (h *handler) GetBucketLocationHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, LocationResponse{Location: bktInfo.LocationConstraint}); err != nil {
h.logAndSendError(ctx, w, "couldn't encode bucket location response", reqInfo, err)
h.logAndSendError(w, "couldn't encode bucket location response", reqInfo, err)
}
}

View file

@ -2,9 +2,6 @@ package handler
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"fmt"
"io"
"net/http"
@ -12,11 +9,9 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/util"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
)
const (
@ -31,18 +26,18 @@ func (h *handler) GetBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
cfg, err := h.obj.GetBucketLifecycleConfiguration(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket lifecycle configuration", reqInfo, err)
h.logAndSendError(w, "could not get bucket lifecycle configuration", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, cfg); err != nil {
h.logAndSendError(ctx, w, "could not encode GetBucketLifecycle response", reqInfo, err)
h.logAndSendError(w, "could not encode GetBucketLifecycle response", reqInfo, err)
return
}
}
@ -57,63 +52,42 @@ func (h *handler) PutBucketLifecycleHandler(w http.ResponseWriter, r *http.Reque
// Content-Md5 is required and should be set
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketLifecycleConfiguration.html
if _, ok := r.Header[api.ContentMD5]; !ok {
h.logAndSendError(ctx, w, "missing Content-MD5", reqInfo, apierr.GetAPIError(apierr.ErrMissingContentMD5))
return
}
headerMD5, err := base64.StdEncoding.DecodeString(r.Header.Get(api.ContentMD5))
if err != nil {
h.logAndSendError(ctx, w, "invalid Content-MD5", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest))
return
}
cfg := new(data.LifecycleConfiguration)
if err = h.cfg.NewXMLDecoder(tee).Decode(cfg); err != nil {
h.logAndSendError(ctx, w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error()))
return
}
bodyMD5, err := getContentMD5(&buf)
if err != nil {
h.logAndSendError(ctx, w, "could not get content md5", reqInfo, err)
return
}
if !bytes.Equal(headerMD5, bodyMD5) {
h.logAndSendError(ctx, w, "Content-MD5 does not match", reqInfo, apierr.GetAPIError(apierr.ErrInvalidDigest))
h.logAndSendError(w, "missing Content-MD5", reqInfo, apiErr.GetAPIError(apiErr.ErrMissingContentMD5))
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
networkInfo, err := h.obj.GetNetworkInfo(ctx)
if err != nil {
h.logAndSendError(ctx, w, "could not get network info", reqInfo, err)
cfg := new(data.LifecycleConfiguration)
if err = h.cfg.NewXMLDecoder(tee).Decode(cfg); err != nil {
h.logAndSendError(w, "could not decode body", reqInfo, fmt.Errorf("%w: %s", apiErr.GetAPIError(apiErr.ErrMalformedXML), err.Error()))
return
}
if err = checkLifecycleConfiguration(ctx, cfg, &networkInfo); err != nil {
h.logAndSendError(ctx, w, "invalid lifecycle configuration", reqInfo, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error()))
if err = checkLifecycleConfiguration(cfg); err != nil {
h.logAndSendError(w, "invalid lifecycle configuration", reqInfo, fmt.Errorf("%w: %s", apiErr.GetAPIError(apiErr.ErrMalformedXML), err.Error()))
return
}
params := &layer.PutBucketLifecycleParams{
BktInfo: bktInfo,
LifecycleCfg: cfg,
LifecycleReader: &buf,
MD5Hash: r.Header.Get(api.ContentMD5),
}
params.CopiesNumbers, err = h.pickCopiesNumbers(parseMetadata(r), reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
if err = h.obj.PutBucketLifecycleConfiguration(ctx, params); err != nil {
h.logAndSendError(ctx, w, "could not put bucket lifecycle configuration", reqInfo, err)
h.logAndSendError(w, "could not put bucket lifecycle configuration", reqInfo, err)
return
}
}
@ -124,27 +98,25 @@ func (h *handler) DeleteBucketLifecycleHandler(w http.ResponseWriter, r *http.Re
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if err = h.obj.DeleteBucketLifecycleConfiguration(ctx, bktInfo); err != nil {
h.logAndSendError(ctx, w, "could not delete bucket lifecycle configuration", reqInfo, err)
h.logAndSendError(w, "could not delete bucket lifecycle configuration", reqInfo, err)
return
}
w.WriteHeader(http.StatusNoContent)
}
func checkLifecycleConfiguration(ctx context.Context, cfg *data.LifecycleConfiguration, ni *netmap.NetworkInfo) error {
now := layer.TimeNow(ctx)
func checkLifecycleConfiguration(cfg *data.LifecycleConfiguration) error {
if len(cfg.Rules) > maxRules {
return fmt.Errorf("number of rules cannot be greater than %d", maxRules)
}
ids := make(map[string]struct{}, len(cfg.Rules))
for i, rule := range cfg.Rules {
for _, rule := range cfg.Rules {
if _, ok := ids[rule.ID]; ok && rule.ID != "" {
return fmt.Errorf("duplicate 'ID': %s", rule.ID)
}
@ -188,19 +160,9 @@ func checkLifecycleConfiguration(ctx context.Context, cfg *data.LifecycleConfigu
return fmt.Errorf("expiration days must be a positive integer: %d", *rule.Expiration.Days)
}
if rule.Expiration.Date != "" {
parsedTime, err := time.Parse("2006-01-02T15:04:05Z", rule.Expiration.Date)
if err != nil {
if _, err := time.Parse("2006-01-02T15:04:05Z", rule.Expiration.Date); rule.Expiration.Date != "" && err != nil {
return fmt.Errorf("invalid value of expiration date: %s", rule.Expiration.Date)
}
epoch, err := util.TimeToEpoch(ni, now, parsedTime)
if err != nil {
return fmt.Errorf("convert time to epoch: %w", err)
}
cfg.Rules[i].Expiration.Epoch = &epoch
}
}
if rule.NonCurrentVersionExpiration != nil {
@ -271,12 +233,3 @@ func checkLifecycleRuleFilter(filter *data.LifecycleRuleFilter) error {
return nil
}
func getContentMD5(reader io.Reader) ([]byte, error) {
hash := md5.New()
_, err := io.Copy(hash, reader)
if err != nil {
return nil, err
}
return hash.Sum(nil), nil
}

View file

@ -1,7 +1,6 @@
package handler
import (
"bytes"
"crypto/md5"
"crypto/rand"
"encoding/base64"
@ -15,7 +14,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"github.com/mr-tron/base58"
"github.com/stretchr/testify/require"
)
@ -341,7 +340,7 @@ func TestPutBucketLifecycleConfiguration(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
if tc.error {
putBucketLifecycleConfigurationErr(hc, bktName, tc.body, apierr.GetAPIError(apierr.ErrMalformedXML))
putBucketLifecycleConfigurationErr(hc, bktName, tc.body, apiErrors.GetAPIError(apiErrors.ErrMalformedXML))
return
}
@ -351,7 +350,7 @@ func TestPutBucketLifecycleConfiguration(t *testing.T) {
require.Equal(t, *tc.body, *cfg)
deleteBucketLifecycleConfiguration(hc, bktName)
getBucketLifecycleConfigurationErr(hc, bktName, apierr.GetAPIError(apierr.ErrNoSuchLifecycleConfiguration))
getBucketLifecycleConfigurationErr(hc, bktName, apiErrors.GetAPIError(apiErrors.ErrNoSuchLifecycleConfiguration))
})
}
}
@ -375,17 +374,17 @@ func TestPutBucketLifecycleInvalidMD5(t *testing.T) {
w, r := prepareTestRequest(hc, bktName, "", lifecycle)
hc.Handler().PutBucketLifecycleHandler(w, r)
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMissingContentMD5))
assertS3Error(hc.t, w, apiErrors.GetAPIError(apiErrors.ErrMissingContentMD5))
w, r = prepareTestRequest(hc, bktName, "", lifecycle)
r.Header.Set(api.ContentMD5, "")
hc.Handler().PutBucketLifecycleHandler(w, r)
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrInvalidDigest))
assertS3Error(hc.t, w, apiErrors.GetAPIError(apiErrors.ErrInvalidDigest))
w, r = prepareTestRequest(hc, bktName, "", lifecycle)
r.Header.Set(api.ContentMD5, "some-hash")
hc.Handler().PutBucketLifecycleHandler(w, r)
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrInvalidDigest))
assertS3Error(hc.t, w, apiErrors.GetAPIError(apiErrors.ErrInvalidDigest))
}
func TestPutBucketLifecycleInvalidXML(t *testing.T) {
@ -394,16 +393,10 @@ func TestPutBucketLifecycleInvalidXML(t *testing.T) {
bktName := "bucket-lifecycle-invalid-xml"
createBucket(hc, bktName)
cfg := &data.CORSConfiguration{}
body, err := xml.Marshal(cfg)
require.NoError(t, err)
contentMD5, err := getContentMD5(bytes.NewReader(body))
require.NoError(t, err)
w, r := prepareTestRequest(hc, bktName, "", cfg)
r.Header.Set(api.ContentMD5, base64.StdEncoding.EncodeToString(contentMD5))
w, r := prepareTestRequest(hc, bktName, "", &data.CORSConfiguration{})
r.Header.Set(api.ContentMD5, "")
hc.Handler().PutBucketLifecycleHandler(w, r)
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrMalformedXML))
assertS3Error(hc.t, w, apiErrors.GetAPIError(apiErrors.ErrMalformedXML))
}
func putBucketLifecycleConfiguration(hc *handlerContext, bktName string, cfg *data.LifecycleConfiguration) {
@ -411,7 +404,7 @@ func putBucketLifecycleConfiguration(hc *handlerContext, bktName string, cfg *da
assertStatus(hc.t, w, http.StatusOK)
}
func putBucketLifecycleConfigurationErr(hc *handlerContext, bktName string, cfg *data.LifecycleConfiguration, err apierr.Error) {
func putBucketLifecycleConfigurationErr(hc *handlerContext, bktName string, cfg *data.LifecycleConfiguration, err apiErrors.Error) {
w := putBucketLifecycleConfigurationBase(hc, bktName, cfg)
assertS3Error(hc.t, w, err)
}
@ -437,7 +430,7 @@ func getBucketLifecycleConfiguration(hc *handlerContext, bktName string) *data.L
return res
}
func getBucketLifecycleConfigurationErr(hc *handlerContext, bktName string, err apierr.Error) {
func getBucketLifecycleConfigurationErr(hc *handlerContext, bktName string, err apiErrors.Error) {
w := getBucketLifecycleConfigurationBase(hc, bktName)
assertS3Error(hc.t, w, err)
}

View file

@ -15,13 +15,12 @@ func (h *handler) ListBucketsHandler(w http.ResponseWriter, r *http.Request) {
var (
own user.ID
res *ListBucketsResponse
ctx = r.Context()
reqInfo = middleware.GetReqInfo(ctx)
reqInfo = middleware.GetReqInfo(r.Context())
)
list, err := h.obj.ListBuckets(ctx)
list, err := h.obj.ListBuckets(r.Context())
if err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
@ -44,6 +43,6 @@ func (h *handler) ListBucketsHandler(w http.ResponseWriter, r *http.Request) {
}
if err = middleware.EncodeToResponse(w, res); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}

View file

@ -9,7 +9,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
)
@ -26,35 +26,34 @@ const (
)
func (h *handler) PutBucketObjectLockConfigHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "couldn't put object locking configuration", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotAllowed))
h.logAndSendError(w, "couldn't put object locking configuration", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotAllowed))
return
}
lockingConf := &data.ObjectLockConfiguration{}
if err = h.cfg.NewXMLDecoder(r.Body).Decode(lockingConf); err != nil {
h.logAndSendError(ctx, w, "couldn't parse locking configuration", reqInfo, err)
h.logAndSendError(w, "couldn't parse locking configuration", reqInfo, err)
return
}
if err = checkLockConfiguration(lockingConf); err != nil {
h.logAndSendError(ctx, w, "invalid lock configuration", reqInfo, err)
h.logAndSendError(w, "invalid lock configuration", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
@ -67,31 +66,30 @@ func (h *handler) PutBucketObjectLockConfigHandler(w http.ResponseWriter, r *htt
Settings: &newSettings,
}
if err = h.obj.PutBucketSettings(ctx, sp); err != nil {
h.logAndSendError(ctx, w, "couldn't put bucket settings", reqInfo, err)
if err = h.obj.PutBucketSettings(r.Context(), sp); err != nil {
h.logAndSendError(w, "couldn't put bucket settings", reqInfo, err)
return
}
}
func (h *handler) GetBucketObjectLockConfigHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "object lock disabled", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound))
h.logAndSendError(w, "object lock disabled", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound))
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
@ -103,34 +101,33 @@ func (h *handler) GetBucketObjectLockConfigHandler(w http.ResponseWriter, r *htt
}
if err = middleware.EncodeToResponse(w, settings.LockConfiguration); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
func (h *handler) PutObjectLegalHoldHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "object lock disabled", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound))
h.logAndSendError(w, "object lock disabled", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound))
return
}
legalHold := &data.LegalHold{}
if err = h.cfg.NewXMLDecoder(r.Body).Decode(legalHold); err != nil {
h.logAndSendError(ctx, w, "couldn't parse legal hold configuration", reqInfo, err)
h.logAndSendError(w, "couldn't parse legal hold configuration", reqInfo, err)
return
}
if legalHold.Status != legalHoldOn && legalHold.Status != legalHoldOff {
h.logAndSendError(ctx, w, "invalid legal hold status", reqInfo,
h.logAndSendError(w, "invalid legal hold status", reqInfo,
fmt.Errorf("invalid status %s", legalHold.Status))
return
}
@ -150,29 +147,28 @@ func (h *handler) PutObjectLegalHoldHandler(w http.ResponseWriter, r *http.Reque
p.CopiesNumbers, err = h.pickCopiesNumbers(parseMetadata(r), reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
if err = h.obj.PutLockInfo(ctx, p); err != nil {
h.logAndSendError(ctx, w, "couldn't head put legal hold", reqInfo, err)
if err = h.obj.PutLockInfo(r.Context(), p); err != nil {
h.logAndSendError(w, "couldn't head put legal hold", reqInfo, err)
return
}
}
func (h *handler) GetObjectLegalHoldHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "object lock disabled", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound))
h.logAndSendError(w, "object lock disabled", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound))
return
}
@ -182,9 +178,9 @@ func (h *handler) GetObjectLegalHoldHandler(w http.ResponseWriter, r *http.Reque
VersionID: reqInfo.URL.Query().Get(api.QueryVersionID),
}
lockInfo, err := h.obj.GetLockInfo(ctx, p)
lockInfo, err := h.obj.GetLockInfo(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "couldn't head lock object", reqInfo, err)
h.logAndSendError(w, "couldn't head lock object", reqInfo, err)
return
}
@ -194,34 +190,33 @@ func (h *handler) GetObjectLegalHoldHandler(w http.ResponseWriter, r *http.Reque
}
if err = middleware.EncodeToResponse(w, legalHold); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
func (h *handler) PutObjectRetentionHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "object lock disabled", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound))
h.logAndSendError(w, "object lock disabled", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound))
return
}
retention := &data.Retention{}
if err = h.cfg.NewXMLDecoder(r.Body).Decode(retention); err != nil {
h.logAndSendError(ctx, w, "couldn't parse object retention", reqInfo, err)
h.logAndSendError(w, "couldn't parse object retention", reqInfo, err)
return
}
lock, err := formObjectLockFromRetention(ctx, retention, r.Header)
lock, err := formObjectLockFromRetention(r.Context(), retention, r.Header)
if err != nil {
h.logAndSendError(ctx, w, "invalid retention configuration", reqInfo, err)
h.logAndSendError(w, "invalid retention configuration", reqInfo, err)
return
}
@ -236,29 +231,28 @@ func (h *handler) PutObjectRetentionHandler(w http.ResponseWriter, r *http.Reque
p.CopiesNumbers, err = h.pickCopiesNumbers(parseMetadata(r), reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
if err = h.obj.PutLockInfo(ctx, p); err != nil {
h.logAndSendError(ctx, w, "couldn't put legal hold", reqInfo, err)
if err = h.obj.PutLockInfo(r.Context(), p); err != nil {
h.logAndSendError(w, "couldn't put legal hold", reqInfo, err)
return
}
}
func (h *handler) GetObjectRetentionHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if !bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "object lock disabled", reqInfo,
apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound))
h.logAndSendError(w, "object lock disabled", reqInfo,
apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound))
return
}
@ -268,14 +262,14 @@ func (h *handler) GetObjectRetentionHandler(w http.ResponseWriter, r *http.Reque
VersionID: reqInfo.URL.Query().Get(api.QueryVersionID),
}
lockInfo, err := h.obj.GetLockInfo(ctx, p)
lockInfo, err := h.obj.GetLockInfo(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "couldn't head lock object", reqInfo, err)
h.logAndSendError(w, "couldn't head lock object", reqInfo, err)
return
}
if !lockInfo.IsRetentionSet() {
h.logAndSendError(ctx, w, "retention lock isn't set", reqInfo, apierr.GetAPIError(apierr.ErrNoSuchKey))
h.logAndSendError(w, "retention lock isn't set", reqInfo, apiErrors.GetAPIError(apiErrors.ErrNoSuchKey))
return
}
@ -288,7 +282,7 @@ func (h *handler) GetObjectRetentionHandler(w http.ResponseWriter, r *http.Reque
}
if err = middleware.EncodeToResponse(w, retention); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
@ -320,7 +314,7 @@ func checkLockConfiguration(conf *data.ObjectLockConfiguration) error {
func formObjectLock(ctx context.Context, bktInfo *data.BucketInfo, defaultConfig *data.ObjectLockConfiguration, header http.Header) (*data.ObjectLock, error) {
if !bktInfo.ObjectLockEnabled {
if existLockHeaders(header) {
return nil, apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound)
return nil, apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound)
}
return nil, nil
}
@ -352,7 +346,7 @@ func formObjectLock(ctx context.Context, bktInfo *data.BucketInfo, defaultConfig
until := header.Get(api.AmzObjectLockRetainUntilDate)
if mode != "" && until == "" || mode == "" && until != "" {
return nil, apierr.GetAPIError(apierr.ErrObjectLockInvalidHeaders)
return nil, apiErrors.GetAPIError(apiErrors.ErrObjectLockInvalidHeaders)
}
if mode != "" {
@ -361,7 +355,7 @@ func formObjectLock(ctx context.Context, bktInfo *data.BucketInfo, defaultConfig
}
if mode != complianceMode && mode != governanceMode {
return nil, apierr.GetAPIError(apierr.ErrUnknownWORMModeDirective)
return nil, apiErrors.GetAPIError(apiErrors.ErrUnknownWORMModeDirective)
}
objectLock.Retention.IsCompliance = mode == complianceMode
@ -370,7 +364,7 @@ func formObjectLock(ctx context.Context, bktInfo *data.BucketInfo, defaultConfig
if until != "" {
retentionDate, err := time.Parse(time.RFC3339, until)
if err != nil {
return nil, apierr.GetAPIError(apierr.ErrInvalidRetentionDate)
return nil, apiErrors.GetAPIError(apiErrors.ErrInvalidRetentionDate)
}
if objectLock.Retention == nil {
objectLock.Retention = &data.RetentionLock{}
@ -388,7 +382,7 @@ func formObjectLock(ctx context.Context, bktInfo *data.BucketInfo, defaultConfig
}
if objectLock.Retention.Until.Before(layer.TimeNow(ctx)) {
return nil, apierr.GetAPIError(apierr.ErrPastObjectLockRetainDate)
return nil, apiErrors.GetAPIError(apiErrors.ErrPastObjectLockRetainDate)
}
}
@ -403,16 +397,16 @@ func existLockHeaders(header http.Header) bool {
func formObjectLockFromRetention(ctx context.Context, retention *data.Retention, header http.Header) (*data.ObjectLock, error) {
if retention.Mode != governanceMode && retention.Mode != complianceMode {
return nil, apierr.GetAPIError(apierr.ErrMalformedXML)
return nil, apiErrors.GetAPIError(apiErrors.ErrMalformedXML)
}
retentionDate, err := time.Parse(time.RFC3339, retention.RetainUntilDate)
if err != nil {
return nil, apierr.GetAPIError(apierr.ErrMalformedXML)
return nil, apiErrors.GetAPIError(apiErrors.ErrMalformedXML)
}
if retentionDate.Before(layer.TimeNow(ctx)) {
return nil, apierr.GetAPIError(apierr.ErrPastObjectLockRetainDate)
return nil, apiErrors.GetAPIError(apiErrors.ErrPastObjectLockRetainDate)
}
var bypass bool

View file

@ -12,7 +12,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"github.com/stretchr/testify/require"
@ -270,23 +270,23 @@ func TestPutBucketLockConfigurationHandler(t *testing.T) {
for _, tc := range []struct {
name string
bucket string
expectedError apierr.Error
expectedError apiErrors.Error
noError bool
configuration *data.ObjectLockConfiguration
}{
{
name: "bkt not found",
expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket),
expectedError: apiErrors.GetAPIError(apiErrors.ErrNoSuchBucket),
},
{
name: "bkt lock disabled",
bucket: bktLockDisabled,
expectedError: apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotAllowed),
expectedError: apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotAllowed),
},
{
name: "invalid configuration",
bucket: bktLockEnabled,
expectedError: apierr.GetAPIError(apierr.ErrInternalError),
expectedError: apiErrors.GetAPIError(apiErrors.ErrInternalError),
configuration: &data.ObjectLockConfiguration{ObjectLockEnabled: "dummy"},
},
{
@ -359,18 +359,18 @@ func TestGetBucketLockConfigurationHandler(t *testing.T) {
for _, tc := range []struct {
name string
bucket string
expectedError apierr.Error
expectedError apiErrors.Error
noError bool
expectedConf *data.ObjectLockConfiguration
}{
{
name: "bkt not found",
expectedError: apierr.GetAPIError(apierr.ErrNoSuchBucket),
expectedError: apiErrors.GetAPIError(apiErrors.ErrNoSuchBucket),
},
{
name: "bkt lock disabled",
bucket: bktLockDisabled,
expectedError: apierr.GetAPIError(apierr.ErrObjectLockConfigurationNotFound),
expectedError: apiErrors.GetAPIError(apiErrors.ErrObjectLockConfigurationNotFound),
},
{
name: "bkt lock enabled empty default",
@ -407,7 +407,7 @@ func TestGetBucketLockConfigurationHandler(t *testing.T) {
}
}
func assertS3Error(t *testing.T, w *httptest.ResponseRecorder, expectedError apierr.Error) {
func assertS3Error(t *testing.T, w *httptest.ResponseRecorder, expectedError apiErrors.Error) {
actualErrorResponse := &middleware.ErrorResponse{}
err := xml.NewDecoder(w.Result().Body).Decode(actualErrorResponse)
require.NoError(t, err)
@ -415,7 +415,7 @@ func assertS3Error(t *testing.T, w *httptest.ResponseRecorder, expectedError api
require.Equal(t, expectedError.HTTPStatusCode, w.Code)
require.Equal(t, expectedError.Code, actualErrorResponse.Code)
if expectedError.ErrCode != apierr.ErrInternalError {
if expectedError.ErrCode != apiErrors.ErrInternalError {
require.Equal(t, expectedError.Description, actualErrorResponse.Message)
}
}
@ -473,33 +473,33 @@ func TestObjectRetention(t *testing.T) {
objName := "obj-for-retention"
createTestObject(hc, bktInfo, objName, encryption.Params{})
getObjectRetention(hc, bktName, objName, nil, apierr.ErrNoSuchKey)
getObjectRetention(hc, bktName, objName, nil, apiErrors.ErrNoSuchKey)
retention := &data.Retention{Mode: governanceMode, RetainUntilDate: time.Now().Add(time.Minute).UTC().Format(time.RFC3339)}
putObjectRetention(hc, bktName, objName, retention, false, 0)
getObjectRetention(hc, bktName, objName, retention, 0)
retention = &data.Retention{Mode: governanceMode, RetainUntilDate: time.Now().UTC().Add(time.Minute).Format(time.RFC3339)}
putObjectRetention(hc, bktName, objName, retention, false, apierr.ErrInternalError)
putObjectRetention(hc, bktName, objName, retention, false, apiErrors.ErrInternalError)
retention = &data.Retention{Mode: complianceMode, RetainUntilDate: time.Now().Add(time.Minute).UTC().Format(time.RFC3339)}
putObjectRetention(hc, bktName, objName, retention, true, 0)
getObjectRetention(hc, bktName, objName, retention, 0)
putObjectRetention(hc, bktName, objName, retention, true, apierr.ErrInternalError)
putObjectRetention(hc, bktName, objName, retention, true, apiErrors.ErrInternalError)
}
func getObjectRetention(hc *handlerContext, bktName, objName string, retention *data.Retention, errCode apierr.ErrorCode) {
func getObjectRetention(hc *handlerContext, bktName, objName string, retention *data.Retention, errCode apiErrors.ErrorCode) {
w, r := prepareTestRequest(hc, bktName, objName, nil)
hc.Handler().GetObjectRetentionHandler(w, r)
if errCode == 0 {
assertRetention(hc.t, w, retention)
} else {
assertS3Error(hc.t, w, apierr.GetAPIError(errCode))
assertS3Error(hc.t, w, apiErrors.GetAPIError(errCode))
}
}
func putObjectRetention(hc *handlerContext, bktName, objName string, retention *data.Retention, byPass bool, errCode apierr.ErrorCode) {
func putObjectRetention(hc *handlerContext, bktName, objName string, retention *data.Retention, byPass bool, errCode apiErrors.ErrorCode) {
w, r := prepareTestRequest(hc, bktName, objName, retention)
if byPass {
r.Header.Set(api.AmzBypassGovernanceRetention, strconv.FormatBool(true))
@ -508,7 +508,7 @@ func putObjectRetention(hc *handlerContext, bktName, objName string, retention *
if errCode == 0 {
assertStatus(hc.t, w, http.StatusOK)
} else {
assertS3Error(hc.t, w, apierr.GetAPIError(errCode))
assertS3Error(hc.t, w, apiErrors.GetAPIError(errCode))
}
}
@ -572,37 +572,37 @@ func TestPutLockErrors(t *testing.T) {
createTestBucketWithLock(hc, bktName, nil)
headers := map[string]string{api.AmzObjectLockMode: complianceMode}
putObjectWithLockFailed(t, hc, bktName, objName, headers, apierr.ErrObjectLockInvalidHeaders)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apiErrors.ErrObjectLockInvalidHeaders)
delete(headers, api.AmzObjectLockMode)
headers[api.AmzObjectLockRetainUntilDate] = time.Now().Add(time.Minute).Format(time.RFC3339)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apierr.ErrObjectLockInvalidHeaders)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apiErrors.ErrObjectLockInvalidHeaders)
headers[api.AmzObjectLockMode] = "dummy"
putObjectWithLockFailed(t, hc, bktName, objName, headers, apierr.ErrUnknownWORMModeDirective)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apiErrors.ErrUnknownWORMModeDirective)
headers[api.AmzObjectLockMode] = complianceMode
headers[api.AmzObjectLockRetainUntilDate] = time.Now().Format(time.RFC3339)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apierr.ErrPastObjectLockRetainDate)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apiErrors.ErrPastObjectLockRetainDate)
headers[api.AmzObjectLockRetainUntilDate] = "dummy"
putObjectWithLockFailed(t, hc, bktName, objName, headers, apierr.ErrInvalidRetentionDate)
putObjectWithLockFailed(t, hc, bktName, objName, headers, apiErrors.ErrInvalidRetentionDate)
putObject(hc, bktName, objName)
retention := &data.Retention{Mode: governanceMode}
putObjectRetentionFailed(t, hc, bktName, objName, retention, apierr.ErrMalformedXML)
putObjectRetentionFailed(t, hc, bktName, objName, retention, apiErrors.ErrMalformedXML)
retention.Mode = "dummy"
retention.RetainUntilDate = time.Now().Add(time.Minute).UTC().Format(time.RFC3339)
putObjectRetentionFailed(t, hc, bktName, objName, retention, apierr.ErrMalformedXML)
putObjectRetentionFailed(t, hc, bktName, objName, retention, apiErrors.ErrMalformedXML)
retention.Mode = governanceMode
retention.RetainUntilDate = time.Now().UTC().Format(time.RFC3339)
putObjectRetentionFailed(t, hc, bktName, objName, retention, apierr.ErrPastObjectLockRetainDate)
putObjectRetentionFailed(t, hc, bktName, objName, retention, apiErrors.ErrPastObjectLockRetainDate)
}
func putObjectWithLockFailed(t *testing.T, hc *handlerContext, bktName, objName string, headers map[string]string, errCode apierr.ErrorCode) {
func putObjectWithLockFailed(t *testing.T, hc *handlerContext, bktName, objName string, headers map[string]string, errCode apiErrors.ErrorCode) {
w, r := prepareTestRequest(hc, bktName, objName, nil)
for key, val := range headers {
@ -610,13 +610,13 @@ func putObjectWithLockFailed(t *testing.T, hc *handlerContext, bktName, objName
}
hc.Handler().PutObjectHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(errCode))
assertS3Error(t, w, apiErrors.GetAPIError(errCode))
}
func putObjectRetentionFailed(t *testing.T, hc *handlerContext, bktName, objName string, retention *data.Retention, errCode apierr.ErrorCode) {
func putObjectRetentionFailed(t *testing.T, hc *handlerContext, bktName, objName string, retention *data.Retention, errCode apiErrors.ErrorCode) {
w, r := prepareTestRequest(hc, bktName, objName, retention)
hc.Handler().PutObjectRetentionHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(errCode))
assertS3Error(t, w, apiErrors.GetAPIError(errCode))
}
func assertRetentionApproximate(t *testing.T, w *httptest.ResponseRecorder, retention *data.Retention, delta float64) {

View file

@ -7,7 +7,6 @@ import (
"net/url"
"path"
"strconv"
"strings"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
@ -104,20 +103,19 @@ const (
)
func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
uploadID := uuid.New()
cannedACLStatus := aclHeadersStatus(r)
additional := []zap.Field{zap.String("uploadID", uploadID.String())}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if cannedACLStatus == aclStatusYes {
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
return
}
@ -133,14 +131,14 @@ func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Re
if len(r.Header.Get(api.AmzTagging)) > 0 {
p.Data.TagSet, err = parseTaggingHeader(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse tagging", reqInfo, err, additional...)
h.logAndSendError(w, "could not parse tagging", reqInfo, err, additional...)
return
}
}
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err, additional...)
h.logAndSendError(w, "invalid sse headers", reqInfo, err, additional...)
return
}
@ -154,12 +152,12 @@ func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Re
p.CopiesNumbers, err = h.pickCopiesNumbers(p.Header, reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err, additional...)
h.logAndSendError(w, "invalid copies number", reqInfo, err, additional...)
return
}
if err = h.obj.CreateMultipartUpload(ctx, p); err != nil {
h.logAndSendError(ctx, w, "could create multipart upload", reqInfo, err, additional...)
if err = h.obj.CreateMultipartUpload(r.Context(), p); err != nil {
h.logAndSendError(w, "could create multipart upload", reqInfo, err, additional...)
return
}
@ -174,18 +172,17 @@ func (h *handler) CreateMultipartUploadHandler(w http.ResponseWriter, r *http.Re
}
if err = middleware.EncodeToResponse(w, resp); err != nil {
h.logAndSendError(ctx, w, "could not encode InitiateMultipartUploadResponse to response", reqInfo, err, additional...)
h.logAndSendError(w, "could not encode InitiateMultipartUploadResponse to response", reqInfo, err, additional...)
return
}
}
func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -198,17 +195,20 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
partNumber, err := strconv.Atoi(partNumStr)
if err != nil || partNumber < layer.UploadMinPartNumber || partNumber > layer.UploadMaxPartNumber {
h.logAndSendError(ctx, w, "invalid part number", reqInfo, errors.GetAPIError(errors.ErrInvalidPartNumber), additional...)
h.logAndSendError(w, "invalid part number", reqInfo, errors.GetAPIError(errors.ErrInvalidPartNumber), additional...)
return
}
body, err := h.getBodyReader(r)
if err != nil {
h.logAndSendError(ctx, w, "failed to get body reader", reqInfo, err, additional...)
h.logAndSendError(w, "failed to get body reader", reqInfo, err, additional...)
return
}
size := h.getPutPayloadSize(r)
var size uint64
if r.ContentLength > 0 {
size = uint64(r.ContentLength)
}
p := &layer.UploadPartParams{
Info: &layer.UploadInfoParams{
@ -225,13 +225,13 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err, additional...)
h.logAndSendError(w, "invalid sse headers", reqInfo, err, additional...)
return
}
hash, err := h.obj.UploadPart(ctx, p)
hash, err := h.obj.UploadPart(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not upload a part", reqInfo, err, additional...)
h.logAndSendError(w, "could not upload a part", reqInfo, err, additional...)
return
}
@ -241,7 +241,7 @@ func (h *handler) UploadPartHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set(api.ETag, data.Quote(hash))
if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
}
@ -259,7 +259,7 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
partNumber, err := strconv.Atoi(partNumStr)
if err != nil || partNumber < layer.UploadMinPartNumber || partNumber > layer.UploadMaxPartNumber {
h.logAndSendError(ctx, w, "invalid part number", reqInfo, errors.GetAPIError(errors.ErrInvalidPartNumber), additional...)
h.logAndSendError(w, "invalid part number", reqInfo, errors.GetAPIError(errors.ErrInvalidPartNumber), additional...)
return
}
@ -270,26 +270,26 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
}
srcBucket, srcObject, err := path2BucketObject(src)
if err != nil {
h.logAndSendError(ctx, w, "invalid source copy", reqInfo, err, additional...)
h.logAndSendError(w, "invalid source copy", reqInfo, err, additional...)
return
}
srcRange, err := parseRange(r.Header.Get(api.AmzCopySourceRange))
if err != nil {
h.logAndSendError(ctx, w, "could not parse copy range", reqInfo,
h.logAndSendError(w, "could not parse copy range", reqInfo,
errors.GetAPIError(errors.ErrInvalidCopyPartRange), additional...)
return
}
srcBktInfo, err := h.getBucketAndCheckOwner(r, srcBucket, api.AmzSourceExpectedBucketOwner)
if err != nil {
h.logAndSendError(ctx, w, "could not get source bucket info", reqInfo, err, additional...)
h.logAndSendError(w, "could not get source bucket info", reqInfo, err, additional...)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get target bucket info", reqInfo, err, additional...)
h.logAndSendError(w, "could not get target bucket info", reqInfo, err, additional...)
return
}
@ -302,35 +302,35 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
srcInfo, err := h.obj.GetObjectInfo(ctx, headPrm)
if err != nil {
if errors.IsS3Error(err, errors.ErrNoSuchKey) && versionID != "" {
h.logAndSendError(ctx, w, "could not head source object version", reqInfo,
h.logAndSendError(w, "could not head source object version", reqInfo,
errors.GetAPIError(errors.ErrBadRequest), additional...)
return
}
h.logAndSendError(ctx, w, "could not head source object", reqInfo, err, additional...)
h.logAndSendError(w, "could not head source object", reqInfo, err, additional...)
return
}
args, err := parseCopyObjectArgs(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse copy object args", reqInfo,
h.logAndSendError(w, "could not parse copy object args", reqInfo,
errors.GetAPIError(errors.ErrInvalidCopyPartRange), additional...)
return
}
if err = checkPreconditions(srcInfo, args.Conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, errors.GetAPIError(errors.ErrPreconditionFailed),
h.logAndSendError(w, "precondition failed", reqInfo, errors.GetAPIError(errors.ErrPreconditionFailed),
additional...)
return
}
srcEncryptionParams, err := formCopySourceEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = srcEncryptionParams.MatchObjectEncryption(layer.FormEncryptionInfo(srcInfo.Headers)); err != nil {
h.logAndSendError(ctx, w, "encryption doesn't match object", reqInfo, fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrBadRequest), err), additional...)
h.logAndSendError(w, "encryption doesn't match object", reqInfo, fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrBadRequest), err), additional...)
return
}
@ -350,13 +350,13 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err, additional...)
h.logAndSendError(w, "invalid sse headers", reqInfo, err, additional...)
return
}
info, err := h.obj.UploadPartCopy(ctx, p)
if err != nil {
h.logAndSendError(ctx, w, "could not upload part copy", reqInfo, err, additional...)
h.logAndSendError(w, "could not upload part copy", reqInfo, err, additional...)
return
}
@ -370,23 +370,22 @@ func (h *handler) UploadPartCopy(w http.ResponseWriter, r *http.Request) {
}
if err = middleware.EncodeToResponse(w, response); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err, additional...)
h.logAndSendError(w, "something went wrong", reqInfo, err, additional...)
}
}
func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
@ -402,12 +401,12 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.
reqBody := new(CompleteMultipartUpload)
if err = h.cfg.NewXMLDecoder(r.Body).Decode(reqBody); err != nil {
h.logAndSendError(ctx, w, "could not read complete multipart upload xml", reqInfo,
h.logAndSendError(w, "could not read complete multipart upload xml", reqInfo,
fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrMalformedXML), err.Error()), additional...)
return
}
if len(reqBody.Parts) == 0 {
h.logAndSendError(ctx, w, "invalid xml with parts", reqInfo, errors.GetAPIError(errors.ErrMalformedXML), additional...)
h.logAndSendError(w, "invalid xml with parts", reqInfo, errors.GetAPIError(errors.ErrMalformedXML), additional...)
return
}
@ -421,7 +420,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.
objInfo, err := h.completeMultipartUpload(r, c, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "complete multipart error", reqInfo, err, additional...)
h.logAndSendError(w, "complete multipart error", reqInfo, err, additional...)
return
}
@ -437,7 +436,7 @@ func (h *handler) CompleteMultipartUploadHandler(w http.ResponseWriter, r *http.
}
if err = middleware.EncodeToResponse(w, response); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err, additional...)
h.logAndSendError(w, "something went wrong", reqInfo, err, additional...)
}
}
@ -496,12 +495,11 @@ func (h *handler) completeMultipartUpload(r *http.Request, c *layer.CompleteMult
}
func (h *handler) ListMultipartUploadsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -514,7 +512,7 @@ func (h *handler) ListMultipartUploadsHandler(w http.ResponseWriter, r *http.Req
if maxUploadsStr != "" {
val, err := strconv.Atoi(maxUploadsStr)
if err != nil || val < 1 || val > 1000 {
h.logAndSendError(ctx, w, "invalid maxUploads", reqInfo, errors.GetAPIError(errors.ErrInvalidMaxUploads))
h.logAndSendError(w, "invalid maxUploads", reqInfo, errors.GetAPIError(errors.ErrInvalidMaxUploads))
return
}
maxUploads = val
@ -530,29 +528,23 @@ func (h *handler) ListMultipartUploadsHandler(w http.ResponseWriter, r *http.Req
UploadIDMarker: queryValues.Get(uploadIDMarkerQueryName),
}
if p.EncodingType != "" && strings.ToLower(p.EncodingType) != urlEncodingType {
h.logAndSendError(ctx, w, "invalid encoding type", reqInfo, errors.GetAPIError(errors.ErrInvalidEncodingMethod))
return
}
list, err := h.obj.ListMultipartUploads(ctx, p)
list, err := h.obj.ListMultipartUploads(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not list multipart uploads", reqInfo, err)
h.logAndSendError(w, "could not list multipart uploads", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, encodeListMultipartUploadsToResponse(list, p)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -561,14 +553,14 @@ func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
queryValues = reqInfo.URL.Query()
uploadID = queryValues.Get(uploadIDHeaderName)
additional = []zap.Field{zap.String("uploadID", uploadID)}
additional = []zap.Field{zap.String("uploadID", uploadID), zap.String("Key", reqInfo.ObjectName)}
maxParts = layer.MaxSizePartsList
)
if queryValues.Get("max-parts") != "" {
val, err := strconv.Atoi(queryValues.Get("max-parts"))
if err != nil || val < 0 {
h.logAndSendError(ctx, w, "invalid MaxParts", reqInfo, errors.GetAPIError(errors.ErrInvalidMaxParts), additional...)
h.logAndSendError(w, "invalid MaxParts", reqInfo, errors.GetAPIError(errors.ErrInvalidMaxParts), additional...)
return
}
if val < layer.MaxSizePartsList {
@ -578,7 +570,7 @@ func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
if queryValues.Get("part-number-marker") != "" {
if partNumberMarker, err = strconv.Atoi(queryValues.Get("part-number-marker")); err != nil || partNumberMarker < 0 {
h.logAndSendError(ctx, w, "invalid PartNumberMarker", reqInfo, err, additional...)
h.logAndSendError(w, "invalid PartNumberMarker", reqInfo, err, additional...)
return
}
}
@ -595,33 +587,32 @@ func (h *handler) ListPartsHandler(w http.ResponseWriter, r *http.Request) {
p.Info.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
list, err := h.obj.ListParts(ctx, p)
list, err := h.obj.ListParts(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "could not list parts", reqInfo, err, additional...)
h.logAndSendError(w, "could not list parts", reqInfo, err, additional...)
return
}
if err = middleware.EncodeToResponse(w, encodeListPartsToResponse(list, p)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
func (h *handler) AbortMultipartUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
uploadID := reqInfo.URL.Query().Get(uploadIDHeaderName)
additional := []zap.Field{zap.String("uploadID", uploadID)}
additional := []zap.Field{zap.String("uploadID", uploadID), zap.String("Key", reqInfo.ObjectName)}
p := &layer.UploadInfoParams{
UploadID: uploadID,
@ -631,12 +622,12 @@ func (h *handler) AbortMultipartUploadHandler(w http.ResponseWriter, r *http.Req
p.Encryption, err = formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
if err = h.obj.AbortMultipartUpload(ctx, p); err != nil {
h.logAndSendError(ctx, w, "could not abort multipart upload", reqInfo, err, additional...)
if err = h.obj.AbortMultipartUpload(r.Context(), p); err != nil {
h.logAndSendError(w, "could not abort multipart upload", reqInfo, err, additional...)
return
}
@ -647,14 +638,14 @@ func encodeListMultipartUploadsToResponse(info *layer.ListMultipartUploadsInfo,
res := ListMultipartUploadsResponse{
Bucket: params.Bkt.Name,
CommonPrefixes: fillPrefixes(info.Prefixes, params.EncodingType),
Delimiter: s3PathEncode(params.Delimiter, params.EncodingType),
Delimiter: params.Delimiter,
EncodingType: params.EncodingType,
IsTruncated: info.IsTruncated,
KeyMarker: s3PathEncode(params.KeyMarker, params.EncodingType),
KeyMarker: params.KeyMarker,
MaxUploads: params.MaxUploads,
NextKeyMarker: s3PathEncode(info.NextKeyMarker, params.EncodingType),
NextKeyMarker: info.NextKeyMarker,
NextUploadIDMarker: info.NextUploadIDMarker,
Prefix: s3PathEncode(params.Prefix, params.EncodingType),
Prefix: params.Prefix,
UploadIDMarker: params.UploadIDMarker,
}
@ -666,7 +657,7 @@ func encodeListMultipartUploadsToResponse(info *layer.ListMultipartUploadsInfo,
ID: u.Owner.String(),
DisplayName: u.Owner.String(),
},
Key: s3PathEncode(u.Key, params.EncodingType),
Key: u.Key,
Owner: Owner{
ID: u.Owner.String(),
DisplayName: u.Owner.String(),

View file

@ -2,21 +2,19 @@ package handler
import (
"crypto/md5"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"encoding/xml"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3Errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
@ -41,7 +39,7 @@ func TestMultipartUploadInvalidPart(t *testing.T) {
etag1, _ := uploadPart(hc, bktName, objName, multipartUpload.UploadID, 1, partSize)
etag2, _ := uploadPart(hc, bktName, objName, multipartUpload.UploadID, 2, partSize)
w := completeMultipartUploadBase(hc, bktName, objName, multipartUpload.UploadID, []string{etag1, etag2})
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrEntityTooSmall))
assertS3Error(hc.t, w, s3Errors.GetAPIError(s3Errors.ErrEntityTooSmall))
}
func TestDeleteMultipartAllParts(t *testing.T) {
@ -105,7 +103,7 @@ func TestMultipartReUploadPart(t *testing.T) {
require.Equal(t, etag2, list.Parts[1].ETag)
w := completeMultipartUploadBase(hc, bktName, objName, uploadInfo.UploadID, []string{etag1, etag2})
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrEntityTooSmall))
assertS3Error(hc.t, w, s3Errors.GetAPIError(s3Errors.ErrEntityTooSmall))
etag1, data1 := uploadPart(hc, bktName, objName, uploadInfo.UploadID, 1, partSizeFirst)
etag2, data2 := uploadPart(hc, bktName, objName, uploadInfo.UploadID, 2, partSizeLast)
@ -253,14 +251,14 @@ func TestListMultipartUploads(t *testing.T) {
})
t.Run("check max uploads", func(t *testing.T) {
listUploads := listMultipartUploads(hc, bktName, "", "", "", "", 2)
listUploads := listMultipartUploadsBase(hc, bktName, "", "", "", "", 2)
require.Len(t, listUploads.Uploads, 2)
require.Equal(t, uploadInfo1.UploadID, listUploads.Uploads[0].UploadID)
require.Equal(t, uploadInfo2.UploadID, listUploads.Uploads[1].UploadID)
})
t.Run("check prefix", func(t *testing.T) {
listUploads := listMultipartUploads(hc, bktName, "/my", "", "", "", -1)
listUploads := listMultipartUploadsBase(hc, bktName, "/my", "", "", "", -1)
require.Len(t, listUploads.Uploads, 2)
require.Equal(t, uploadInfo1.UploadID, listUploads.Uploads[0].UploadID)
require.Equal(t, uploadInfo2.UploadID, listUploads.Uploads[1].UploadID)
@ -268,7 +266,7 @@ func TestListMultipartUploads(t *testing.T) {
t.Run("check markers", func(t *testing.T) {
t.Run("check only key-marker", func(t *testing.T) {
listUploads := listMultipartUploads(hc, bktName, "", "", "", objName2, -1)
listUploads := listMultipartUploadsBase(hc, bktName, "", "", "", objName2, -1)
require.Len(t, listUploads.Uploads, 1)
// If upload-id-marker is not specified, only the keys lexicographically greater than the specified key-marker will be included in the list.
require.Equal(t, uploadInfo3.UploadID, listUploads.Uploads[0].UploadID)
@ -279,7 +277,7 @@ func TestListMultipartUploads(t *testing.T) {
if uploadIDMarker > uploadInfo2.UploadID {
uploadIDMarker = uploadInfo2.UploadID
}
listUploads := listMultipartUploads(hc, bktName, "", "", uploadIDMarker, "", -1)
listUploads := listMultipartUploadsBase(hc, bktName, "", "", uploadIDMarker, "", -1)
// If key-marker is not specified, the upload-id-marker parameter is ignored.
require.Len(t, listUploads.Uploads, 3)
})
@ -287,7 +285,7 @@ func TestListMultipartUploads(t *testing.T) {
t.Run("check key-marker along with upload-id-marker", func(t *testing.T) {
uploadIDMarker := "00000000-0000-0000-0000-000000000000"
listUploads := listMultipartUploads(hc, bktName, "", "", uploadIDMarker, objName3, -1)
listUploads := listMultipartUploadsBase(hc, bktName, "", "", uploadIDMarker, objName3, -1)
require.Len(t, listUploads.Uploads, 1)
// If upload-id-marker is specified, any multipart uploads for a key equal to the key-marker might also be included,
// provided those multipart uploads have upload IDs lexicographically greater than the specified upload-id-marker.
@ -522,7 +520,7 @@ func TestUploadPartCheckContentSHA256(t *testing.T) {
r.Header.Set(api.AmzContentSha256, tc.hash)
hc.Handler().UploadPartHandler(w, r)
if tc.error {
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrContentSHA256Mismatch))
assertS3Error(t, w, s3Errors.GetAPIError(s3Errors.ErrContentSHA256Mismatch))
list := listParts(hc, bktName, objName, multipartUpload.UploadID, "0", http.StatusOK)
require.Len(t, list.Parts, 1)
@ -623,73 +621,6 @@ func TestMultipartObjectLocation(t *testing.T) {
}
}
func TestUploadPartWithNegativeContentLength(t *testing.T) {
hc := prepareHandlerContext(t)
bktName, objName := "bucket-to-upload-part", "object-multipart"
createTestBucket(hc, bktName)
partSize := 5 * 1024 * 1024
multipartUpload := createMultipartUpload(hc, bktName, objName, map[string]string{})
partBody := make([]byte, partSize)
_, err := rand.Read(partBody)
require.NoError(hc.t, err)
query := make(url.Values)
query.Set(uploadIDQuery, multipartUpload.UploadID)
query.Set(partNumberQuery, "1")
w, r := prepareTestRequestWithQuery(hc, bktName, objName, query, partBody)
r.ContentLength = -1
hc.Handler().UploadPartHandler(w, r)
assertStatus(hc.t, w, http.StatusOK)
completeMultipartUpload(hc, bktName, objName, multipartUpload.UploadID, []string{w.Header().Get(api.ETag)})
res, _ := getObject(hc, bktName, objName)
equalDataSlices(t, partBody, res)
resp := getObjectAttributes(hc, bktName, objName, objectParts)
require.Len(t, resp.ObjectParts.Parts, 1)
require.Equal(t, partSize, resp.ObjectParts.Parts[0].Size)
}
func TestListMultipartUploadsEncoding(t *testing.T) {
hc := prepareHandlerContext(t)
bktName := "bucket-to-list-uploads-encoding"
createTestBucket(hc, bktName)
listAllMultipartUploadsErr(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod))
objects := []string{"foo()/bar", "foo()/bar/xyzzy", "asdf+b"}
for _, objName := range objects {
createMultipartUpload(hc, bktName, objName, nil)
}
listResponse := listMultipartUploadsURL(hc, bktName, "foo(", ")", "", "", -1)
require.Len(t, listResponse.CommonPrefixes, 1)
require.Equal(t, "foo%28%29", listResponse.CommonPrefixes[0].Prefix)
require.Equal(t, "foo%28", listResponse.Prefix)
require.Equal(t, "%29", listResponse.Delimiter)
require.Equal(t, "url", listResponse.EncodingType)
require.Equal(t, maxObjectList, listResponse.MaxUploads)
listResponse = listMultipartUploads(hc, bktName, "", "", "", "", 1)
require.Empty(t, listResponse.EncodingType)
listResponse = listMultipartUploadsURL(hc, bktName, "", "", "", listResponse.NextKeyMarker, 1)
require.Len(t, listResponse.CommonPrefixes, 0)
require.Len(t, listResponse.Uploads, 1)
require.Equal(t, "foo%28%29/bar", listResponse.Uploads[0].Key)
require.Equal(t, "asdf%2Bb", listResponse.KeyMarker)
require.Equal(t, "foo%28%29/bar", listResponse.NextKeyMarker)
require.Equal(t, "url", listResponse.EncodingType)
require.Equal(t, 1, listResponse.MaxUploads)
}
func uploadPartCopy(hc *handlerContext, bktName, objName, uploadID string, num int, srcObj string, start, end int) *UploadPartCopyResponse {
return uploadPartCopyBase(hc, bktName, objName, false, uploadID, num, srcObj, start, end)
}
@ -715,42 +646,16 @@ func uploadPartCopyBase(hc *handlerContext, bktName, objName string, encrypted b
return uploadPartCopyResponse
}
func listMultipartUploads(hc *handlerContext, bktName, prefix, delimiter, uploadIDMarker, keyMarker string, maxUploads int) *ListMultipartUploadsResponse {
w := listMultipartUploadsBase(hc, bktName, prefix, delimiter, uploadIDMarker, keyMarker, "", maxUploads)
assertStatus(hc.t, w, http.StatusOK)
res := &ListMultipartUploadsResponse{}
parseTestResponse(hc.t, w, res)
return res
}
func listMultipartUploadsURL(hc *handlerContext, bktName, prefix, delimiter, uploadIDMarker, keyMarker string, maxUploads int) *ListMultipartUploadsResponse {
w := listMultipartUploadsBase(hc, bktName, prefix, delimiter, uploadIDMarker, keyMarker, urlEncodingType, maxUploads)
assertStatus(hc.t, w, http.StatusOK)
res := &ListMultipartUploadsResponse{}
parseTestResponse(hc.t, w, res)
return res
}
func listAllMultipartUploads(hc *handlerContext, bktName string) *ListMultipartUploadsResponse {
w := listMultipartUploadsBase(hc, bktName, "", "", "", "", "", -1)
assertStatus(hc.t, w, http.StatusOK)
res := &ListMultipartUploadsResponse{}
parseTestResponse(hc.t, w, res)
return res
return listMultipartUploadsBase(hc, bktName, "", "", "", "", -1)
}
func listAllMultipartUploadsErr(hc *handlerContext, bktName, encoding string, err apierr.Error) {
w := listMultipartUploadsBase(hc, bktName, "", "", "", "", encoding, -1)
assertS3Error(hc.t, w, err)
}
func listMultipartUploadsBase(hc *handlerContext, bktName, prefix, delimiter, uploadIDMarker, keyMarker, encoding string, maxUploads int) *httptest.ResponseRecorder {
func listMultipartUploadsBase(hc *handlerContext, bktName, prefix, delimiter, uploadIDMarker, keyMarker string, maxUploads int) *ListMultipartUploadsResponse {
query := make(url.Values)
query.Set(prefixQueryName, prefix)
query.Set(delimiterQueryName, delimiter)
query.Set(uploadIDMarkerQueryName, uploadIDMarker)
query.Set(keyMarkerQueryName, keyMarker)
query.Set(encodingTypeQueryName, encoding)
if maxUploads != -1 {
query.Set(maxUploadsQueryName, strconv.Itoa(maxUploads))
}
@ -758,7 +663,10 @@ func listMultipartUploadsBase(hc *handlerContext, bktName, prefix, delimiter, up
w, r := prepareTestRequestWithQuery(hc, bktName, "", query, nil)
hc.Handler().ListMultipartUploadsHandler(w, r)
return w
listPartsResponse := &ListMultipartUploadsResponse{}
readResponse(hc.t, w, http.StatusOK, listPartsResponse)
return listPartsResponse
}
func listParts(hc *handlerContext, bktName, objName string, uploadID, partNumberMarker string, status int) *ListPartsResponse {

View file

@ -8,5 +8,5 @@ import (
)
func (h *handler) DeleteBucketEncryptionHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not supported", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotSupported))
h.logAndSendError(w, "not supported", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotSupported))
}

View file

@ -4,7 +4,6 @@ import (
"net/http"
"net/url"
"strconv"
"strings"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
@ -17,27 +16,26 @@ import (
// ListObjectsV1Handler handles objects listing requests for API version 1.
func (h *handler) ListObjectsV1Handler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
params, err := parseListObjectsArgsV1(reqInfo)
if err != nil {
h.logAndSendError(ctx, w, "failed to parse arguments", reqInfo, err)
h.logAndSendError(w, "failed to parse arguments", reqInfo, err)
return
}
if params.BktInfo, err = h.getBucketAndCheckOwner(r, reqInfo.BucketName); err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
list, err := h.obj.ListObjectsV1(ctx, params)
list, err := h.obj.ListObjectsV1(r.Context(), params)
if err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, h.encodeV1(params, list)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
@ -62,27 +60,26 @@ func (h *handler) encodeV1(p *layer.ListObjectsParamsV1, list *layer.ListObjects
// ListObjectsV2Handler handles objects listing requests for API version 2.
func (h *handler) ListObjectsV2Handler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
params, err := parseListObjectsArgsV2(reqInfo)
if err != nil {
h.logAndSendError(ctx, w, "failed to parse arguments", reqInfo, err)
h.logAndSendError(w, "failed to parse arguments", reqInfo, err)
return
}
if params.BktInfo, err = h.getBucketAndCheckOwner(r, reqInfo.BucketName); err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
list, err := h.obj.ListObjectsV2(ctx, params)
list, err := h.obj.ListObjectsV2(r.Context(), params)
if err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, h.encodeV2(params, list)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
@ -156,10 +153,6 @@ func parseListObjectArgs(reqInfo *middleware.ReqInfo) (*layer.ListObjectsParamsC
res.Delimiter = queryValues.Get("delimiter")
res.Encode = queryValues.Get("encoding-type")
if res.Encode != "" && strings.ToLower(res.Encode) != urlEncodingType {
return nil, errors.GetAPIError(errors.ErrInvalidEncodingMethod)
}
if queryValues.Get("max-keys") == "" {
res.MaxKeys = maxObjectList
} else if res.MaxKeys, err = strconv.Atoi(queryValues.Get("max-keys")); err != nil || res.MaxKeys < 0 {
@ -221,28 +214,27 @@ func fillContents(src []*data.ExtendedNodeVersion, encode string, fetchOwner, md
}
func (h *handler) ListBucketObjectVersionsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
p, err := parseListObjectVersionsRequest(reqInfo)
if err != nil {
h.logAndSendError(ctx, w, "failed to parse request", reqInfo, err)
h.logAndSendError(w, "failed to parse request", reqInfo, err)
return
}
if p.BktInfo, err = h.getBucketAndCheckOwner(r, reqInfo.BucketName); err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
info, err := h.obj.ListObjectVersions(ctx, p)
info, err := h.obj.ListObjectVersions(r.Context(), p)
if err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
response := encodeListObjectVersionsToResponse(p, info, p.BktInfo.Name, h.cfg.MD5Enabled())
if err = middleware.EncodeToResponse(w, response); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
@ -265,10 +257,6 @@ func parseListObjectVersionsRequest(reqInfo *middleware.ReqInfo) (*layer.ListObj
res.Encode = queryValues.Get("encoding-type")
res.VersionIDMarker = queryValues.Get("version-id-marker")
if res.Encode != "" && strings.ToLower(res.Encode) != urlEncodingType {
return nil, errors.GetAPIError(errors.ErrInvalidEncodingMethod)
}
if res.VersionIDMarker != "" && res.KeyMarker == "" {
return nil, errors.GetAPIError(errors.VersionIDMarkerWithoutKeyMarker)
}

View file

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strconv"
@ -15,7 +14,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/cache"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
@ -757,16 +755,6 @@ func TestListObjectVersionsEncoding(t *testing.T) {
require.Equal(t, 3, listResponse.MaxKeys)
}
func TestListingsWithInvalidEncodingType(t *testing.T) {
hc := prepareHandlerContext(t)
bktName := "bucket-for-listing-invalid-encoding"
createTestBucket(hc, bktName)
listObjectsVersionsErr(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod))
listObjectsV2Err(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod))
listObjectsV1Err(hc, bktName, "invalid", apierr.GetAPIError(apierr.ErrInvalidEncodingMethod))
}
func checkVersionsNames(t *testing.T, versions *ListObjectsVersionsResponse, names []string) {
for i, v := range versions.Version {
require.Equal(t, names[i], v.Key)
@ -774,19 +762,10 @@ func checkVersionsNames(t *testing.T, versions *ListObjectsVersionsResponse, nam
}
func listObjectsV2(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken string, maxKeys int) *ListObjectsV2Response {
w := listObjectsV2Base(hc, bktName, prefix, delimiter, startAfter, continuationToken, "", maxKeys)
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsV2Response{}
parseTestResponse(hc.t, w, res)
return res
return listObjectsV2Ext(hc, bktName, prefix, delimiter, startAfter, continuationToken, "", maxKeys)
}
func listObjectsV2Err(hc *handlerContext, bktName, encoding string, err apierr.Error) {
w := listObjectsV2Base(hc, bktName, "", "", "", "", encoding, -1)
assertS3Error(hc.t, w, err)
}
func listObjectsV2Base(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken, encodingType string, maxKeys int) *httptest.ResponseRecorder {
func listObjectsV2Ext(hc *handlerContext, bktName, prefix, delimiter, startAfter, continuationToken, encodingType string, maxKeys int) *ListObjectsV2Response {
query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys)
query.Add("fetch-owner", "true")
if len(startAfter) != 0 {
@ -801,7 +780,10 @@ func listObjectsV2Base(hc *handlerContext, bktName, prefix, delimiter, startAfte
w, r := prepareTestFullRequest(hc, bktName, "", query, nil)
hc.Handler().ListObjectsV2Handler(w, r)
return w
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsV2Response{}
parseTestResponse(hc.t, w, res)
return res
}
func validateListV1(t *testing.T, tc *handlerContext, bktName, prefix, delimiter, marker string, maxKeys int,
@ -861,54 +843,28 @@ func prepareCommonListObjectsQuery(prefix, delimiter string, maxKeys int) url.Va
}
func listObjectsV1(hc *handlerContext, bktName, prefix, delimiter, marker string, maxKeys int) *ListObjectsV1Response {
w := listObjectsV1Base(hc, bktName, prefix, delimiter, marker, "", maxKeys)
query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys)
if len(marker) != 0 {
query.Add("marker", marker)
}
w, r := prepareTestFullRequest(hc, bktName, "", query, nil)
hc.Handler().ListObjectsV1Handler(w, r)
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsV1Response{}
parseTestResponse(hc.t, w, res)
return res
}
func listObjectsV1Err(hc *handlerContext, bktName, encoding string, err apierr.Error) {
w := listObjectsV1Base(hc, bktName, "", "", "", encoding, -1)
assertS3Error(hc.t, w, err)
}
func listObjectsV1Base(hc *handlerContext, bktName, prefix, delimiter, marker, encoding string, maxKeys int) *httptest.ResponseRecorder {
query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys)
if len(marker) != 0 {
query.Add("marker", marker)
}
if len(encoding) != 0 {
query.Add("encoding-type", encoding)
}
w, r := prepareTestFullRequest(hc, bktName, "", query, nil)
hc.Handler().ListObjectsV1Handler(w, r)
return w
}
func listObjectsVersions(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int) *ListObjectsVersionsResponse {
w := listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, "", maxKeys)
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsVersionsResponse{}
parseTestResponse(hc.t, w, res)
return res
return listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, maxKeys, false)
}
func listObjectsVersionsURL(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int) *ListObjectsVersionsResponse {
w := listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, urlEncodingType, maxKeys)
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsVersionsResponse{}
parseTestResponse(hc.t, w, res)
return res
return listObjectsVersionsBase(hc, bktName, prefix, delimiter, keyMarker, versionIDMarker, maxKeys, true)
}
func listObjectsVersionsErr(hc *handlerContext, bktName, encoding string, err apierr.Error) {
w := listObjectsVersionsBase(hc, bktName, "", "", "", "", encoding, -1)
assertS3Error(hc.t, w, err)
}
func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker, encoding string, maxKeys int) *httptest.ResponseRecorder {
func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, keyMarker, versionIDMarker string, maxKeys int, encode bool) *ListObjectsVersionsResponse {
query := prepareCommonListObjectsQuery(prefix, delimiter, maxKeys)
if len(keyMarker) != 0 {
query.Add("key-marker", keyMarker)
@ -916,11 +872,14 @@ func listObjectsVersionsBase(hc *handlerContext, bktName, prefix, delimiter, key
if len(versionIDMarker) != 0 {
query.Add("version-id-marker", versionIDMarker)
}
if len(encoding) != 0 {
query.Add("encoding-type", encoding)
if encode {
query.Add("encoding-type", "url")
}
w, r := prepareTestFullRequest(hc, bktName, "", query, nil)
hc.Handler().ListBucketObjectVersionsHandler(w, r)
return w
assertStatus(hc.t, w, http.StatusOK)
res := &ListObjectsVersionsResponse{}
parseTestResponse(hc.t, w, res)
return res
}

View file

@ -22,30 +22,30 @@ func (h *handler) PatchObjectHandler(w http.ResponseWriter, r *http.Request) {
reqInfo := middleware.GetReqInfo(ctx)
if _, ok := r.Header[api.ContentRange]; !ok {
h.logAndSendError(ctx, w, "missing Content-Range", reqInfo, errors.GetAPIError(errors.ErrMissingContentRange))
h.logAndSendError(w, "missing Content-Range", reqInfo, errors.GetAPIError(errors.ErrMissingContentRange))
return
}
if _, ok := r.Header[api.ContentLength]; !ok {
h.logAndSendError(ctx, w, "missing Content-Length", reqInfo, errors.GetAPIError(errors.ErrMissingContentLength))
h.logAndSendError(w, "missing Content-Length", reqInfo, errors.GetAPIError(errors.ErrMissingContentLength))
return
}
conditional, err := parsePatchConditionalHeaders(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse conditional headers", reqInfo, err)
h.logAndSendError(w, "could not parse conditional headers", reqInfo, err)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
@ -57,40 +57,40 @@ func (h *handler) PatchObjectHandler(w http.ResponseWriter, r *http.Request) {
extendedSrcObjInfo, err := h.obj.GetExtendedObjectInfo(ctx, srcObjPrm)
if err != nil {
h.logAndSendError(ctx, w, "could not find object", reqInfo, err)
h.logAndSendError(w, "could not find object", reqInfo, err)
return
}
srcObjInfo := extendedSrcObjInfo.ObjectInfo
if err = checkPreconditions(srcObjInfo, conditional, h.cfg.MD5Enabled()); err != nil {
h.logAndSendError(ctx, w, "precondition failed", reqInfo, err)
h.logAndSendError(w, "precondition failed", reqInfo, err)
return
}
srcSize, err := layer.GetObjectSize(srcObjInfo)
if err != nil {
h.logAndSendError(ctx, w, "failed to get source object size", reqInfo, err)
h.logAndSendError(w, "failed to get source object size", reqInfo, err)
return
}
byteRange, err := parsePatchByteRange(r.Header.Get(api.ContentRange), srcSize)
if err != nil {
h.logAndSendError(ctx, w, "could not parse byte range", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err))
h.logAndSendError(w, "could not parse byte range", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err))
return
}
if maxPatchSize < byteRange.End-byteRange.Start+1 {
h.logAndSendError(ctx, w, "byte range length is longer than allowed", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err))
h.logAndSendError(w, "byte range length is longer than allowed", reqInfo, errors.GetAPIError(errors.ErrInvalidRange), zap.Error(err))
return
}
if uint64(r.ContentLength) != (byteRange.End - byteRange.Start + 1) {
h.logAndSendError(ctx, w, "content-length must be equal to byte range length", reqInfo, errors.GetAPIError(errors.ErrInvalidRangeLength))
h.logAndSendError(w, "content-length must be equal to byte range length", reqInfo, errors.GetAPIError(errors.ErrInvalidRangeLength))
return
}
if byteRange.Start > srcSize {
h.logAndSendError(ctx, w, "start byte is greater than object size", reqInfo, errors.GetAPIError(errors.ErrRangeOutOfBounds))
h.logAndSendError(w, "start byte is greater than object size", reqInfo, errors.GetAPIError(errors.ErrRangeOutOfBounds))
return
}
@ -104,16 +104,16 @@ func (h *handler) PatchObjectHandler(w http.ResponseWriter, r *http.Request) {
params.CopiesNumbers, err = h.pickCopiesNumbers(nil, reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
extendedObjInfo, err := h.obj.PatchObject(ctx, params)
if err != nil {
if isErrObjectLocked(err) {
h.logAndSendError(ctx, w, "object is locked", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
h.logAndSendError(w, "object is locked", reqInfo, errors.GetAPIError(errors.ErrAccessDenied))
} else {
h.logAndSendError(ctx, w, "could not patch object", reqInfo, err)
h.logAndSendError(w, "could not patch object", reqInfo, err)
}
return
}
@ -132,7 +132,7 @@ func (h *handler) PatchObjectHandler(w http.ResponseWriter, r *http.Request) {
}
if err = middleware.EncodeToResponse(w, resp); err != nil {
h.logAndSendError(ctx, w, "could not encode PatchObjectResult to response", reqInfo, err)
h.logAndSendError(w, "could not encode PatchObjectResult to response", reqInfo, err)
return
}
}

View file

@ -18,7 +18,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"github.com/stretchr/testify/require"
)
@ -50,7 +50,7 @@ func TestPatch(t *testing.T) {
name string
rng string
headers map[string]string
code apierr.ErrorCode
code s3errors.ErrorCode
}{
{
name: "success",
@ -63,22 +63,22 @@ func TestPatch(t *testing.T) {
{
name: "invalid range syntax",
rng: "bytes 0-2",
code: apierr.ErrInvalidRange,
code: s3errors.ErrInvalidRange,
},
{
name: "invalid range length",
rng: "bytes 0-5/*",
code: apierr.ErrInvalidRangeLength,
code: s3errors.ErrInvalidRangeLength,
},
{
name: "invalid range start",
rng: "bytes 20-22/*",
code: apierr.ErrRangeOutOfBounds,
code: s3errors.ErrRangeOutOfBounds,
},
{
name: "range is too long",
rng: "bytes 0-5368709120/*",
code: apierr.ErrInvalidRange,
code: s3errors.ErrInvalidRange,
},
{
name: "If-Unmodified-Since precondition are not satisfied",
@ -86,7 +86,7 @@ func TestPatch(t *testing.T) {
headers: map[string]string{
api.IfUnmodifiedSince: created.Add(-24 * time.Hour).Format(http.TimeFormat),
},
code: apierr.ErrPreconditionFailed,
code: s3errors.ErrPreconditionFailed,
},
{
name: "If-Match precondition are not satisfied",
@ -94,7 +94,7 @@ func TestPatch(t *testing.T) {
headers: map[string]string{
api.IfMatch: "etag",
},
code: apierr.ErrPreconditionFailed,
code: s3errors.ErrPreconditionFailed,
},
} {
t.Run(tt.name, func(t *testing.T) {
@ -377,7 +377,7 @@ func TestPatchEncryptedObject(t *testing.T) {
tc.Handler().PutObjectHandler(w, r)
assertStatus(t, w, http.StatusOK)
patchObjectErr(t, tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, apierr.ErrInternalError)
patchObjectErr(t, tc, bktName, objName, "bytes 2-4/*", []byte("new"), nil, s3errors.ErrInternalError)
}
func TestPatchMissingHeaders(t *testing.T) {
@ -393,13 +393,13 @@ func TestPatchMissingHeaders(t *testing.T) {
w = httptest.NewRecorder()
r = httptest.NewRequest(http.MethodPatch, defaultURL, strings.NewReader("new"))
tc.Handler().PatchObjectHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentRange))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrMissingContentRange))
w = httptest.NewRecorder()
r = httptest.NewRequest(http.MethodPatch, defaultURL, strings.NewReader("new"))
r.Header.Set(api.ContentRange, "bytes 0-2/*")
tc.Handler().PatchObjectHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentLength))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrMissingContentLength))
}
func TestParsePatchByteRange(t *testing.T) {
@ -501,9 +501,9 @@ func patchObjectVersion(t *testing.T, tc *handlerContext, bktName, objName, vers
return result
}
func patchObjectErr(t *testing.T, tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code apierr.ErrorCode) {
func patchObjectErr(t *testing.T, tc *handlerContext, bktName, objName, rng string, payload []byte, headers map[string]string, code s3errors.ErrorCode) {
w := patchObjectBase(tc, bktName, objName, "", rng, payload, headers)
assertS3Error(t, w, apierr.GetAPIError(code))
assertS3Error(t, w, s3errors.GetAPIError(code))
}
func patchObjectBase(tc *handlerContext, bktName, objName, version, rng string, payload []byte, headers map[string]string) *httptest.ResponseRecorder {

View file

@ -2,12 +2,11 @@ package handler
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"encoding/json"
"encoding/xml"
"errors"
stderrors "errors"
"fmt"
"io"
"mime/multipart"
@ -21,7 +20,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/encryption"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
@ -92,11 +91,11 @@ func (p *postPolicy) CheckField(key string, value string) error {
}
cond := p.condition(key)
if cond == nil {
return apierr.GetAPIError(apierr.ErrPostPolicyConditionInvalidFormat)
return errors.GetAPIError(errors.ErrPostPolicyConditionInvalidFormat)
}
if !cond.match(value) {
return apierr.GetAPIError(apierr.ErrPostPolicyConditionInvalidFormat)
return errors.GetAPIError(errors.ErrPostPolicyConditionInvalidFormat)
}
return nil
@ -193,24 +192,24 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket objInfo", reqInfo, err)
h.logAndSendError(w, "could not get bucket objInfo", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
if cannedACLStatus == aclStatusYes {
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, apierr.GetAPIError(apierr.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
return
}
tagSet, err := parseTaggingHeader(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse tagging header", reqInfo, err)
h.logAndSendError(w, "could not parse tagging header", reqInfo, err)
return
}
@ -230,44 +229,44 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
encryptionParams, err := formEncryptionParams(r)
if err != nil {
h.logAndSendError(ctx, w, "invalid sse headers", reqInfo, err)
h.logAndSendError(w, "invalid sse headers", reqInfo, err)
return
}
body, err := h.getBodyReader(r)
if err != nil {
h.logAndSendError(ctx, w, "failed to get body reader", reqInfo, err)
h.logAndSendError(w, "failed to get body reader", reqInfo, err)
return
}
if encodings := r.Header.Get(api.ContentEncoding); len(encodings) > 0 {
metadata[api.ContentEncoding] = encodings
}
size := h.getPutPayloadSize(r)
var size uint64
if r.ContentLength > 0 {
size = uint64(r.ContentLength)
}
params := &layer.PutObjectParams{
BktInfo: bktInfo,
Object: reqInfo.ObjectName,
Reader: body,
Size: size,
Header: metadata,
Encryption: encryptionParams,
ContentMD5: r.Header.Get(api.ContentMD5),
ContentSHA256Hash: r.Header.Get(api.AmzContentSha256),
}
if size > 0 {
params.Size = &size
}
params.CopiesNumbers, err = h.pickCopiesNumbers(metadata, reqInfo.Namespace, bktInfo.LocationConstraint)
if err != nil {
h.logAndSendError(ctx, w, "invalid copies number", reqInfo, err)
h.logAndSendError(w, "invalid copies number", reqInfo, err)
return
}
params.Lock, err = formObjectLock(ctx, bktInfo, settings.LockConfiguration, r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not form object lock", reqInfo, err)
h.logAndSendError(w, "could not form object lock", reqInfo, err)
return
}
@ -275,7 +274,7 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
_, err2 := io.Copy(io.Discard, body)
err3 := body.Close()
h.logAndSendError(ctx, w, "could not upload object", reqInfo, err, zap.Errors("body close errors", []error{err2, err3}))
h.logAndSendError(w, "could not upload object", reqInfo, err, zap.Errors("body close errors", []error{err2, err3}))
return
}
objInfo := extendedObjInfo.ObjectInfo
@ -290,8 +289,8 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
TagSet: tagSet,
NodeVersion: extendedObjInfo.NodeVersion,
}
if err = h.obj.PutObjectTagging(ctx, tagPrm); err != nil {
h.logAndSendError(ctx, w, "could not upload object tagging", reqInfo, err)
if err = h.obj.PutObjectTagging(r.Context(), tagPrm); err != nil {
h.logAndSendError(w, "could not upload object tagging", reqInfo, err)
return
}
}
@ -306,13 +305,14 @@ func (h *handler) PutObjectHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set(api.ETag, data.Quote(objInfo.ETag(h.cfg.MD5Enabled())))
if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
}
func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) {
if !api.IsSignedStreamingV4(r) {
shaType, streaming := api.IsSignedStreamingV4(r)
if !streaming {
return r.Body, nil
}
@ -333,19 +333,28 @@ func (h *handler) getBodyReader(r *http.Request) (io.ReadCloser, error) {
if !chunkedEncoding && !h.cfg.BypassContentEncodingInChunks() {
return nil, fmt.Errorf("%w: request is not chunk encoded, encodings '%s'",
apierr.GetAPIError(apierr.ErrInvalidEncodingMethod), strings.Join(encodings, ","))
errors.GetAPIError(errors.ErrInvalidEncodingMethod), strings.Join(encodings, ","))
}
decodeContentSize := r.Header.Get(api.AmzDecodedContentLength)
if len(decodeContentSize) == 0 {
return nil, apierr.GetAPIError(apierr.ErrMissingContentLength)
return nil, errors.GetAPIError(errors.ErrMissingContentLength)
}
if _, err := strconv.Atoi(decodeContentSize); err != nil {
return nil, fmt.Errorf("%w: parse decoded content length: %s", apierr.GetAPIError(apierr.ErrMissingContentLength), err.Error())
return nil, fmt.Errorf("%w: parse decoded content length: %s", errors.GetAPIError(errors.ErrMissingContentLength), err.Error())
}
var (
err error
chunkReader io.ReadCloser
)
if shaType == api.StreamingContentV4aSHA256 {
chunkReader, err = newSignV4aChunkedReader(r)
} else {
chunkReader, err = newSignV4ChunkedReader(r)
}
chunkReader, err := newSignV4ChunkedReader(r)
if err != nil {
return nil, fmt.Errorf("initialize chunk reader: %w", err)
}
@ -378,43 +387,43 @@ func formEncryptionParamsBase(r *http.Request, isCopySource bool) (enc encryptio
}
if r.TLS == nil {
return enc, apierr.GetAPIError(apierr.ErrInsecureSSECustomerRequest)
return enc, errors.GetAPIError(errors.ErrInsecureSSECustomerRequest)
}
if len(sseCustomerKey) > 0 && len(sseCustomerAlgorithm) == 0 {
return enc, apierr.GetAPIError(apierr.ErrMissingSSECustomerAlgorithm)
return enc, errors.GetAPIError(errors.ErrMissingSSECustomerAlgorithm)
}
if len(sseCustomerAlgorithm) > 0 && len(sseCustomerKey) == 0 {
return enc, apierr.GetAPIError(apierr.ErrMissingSSECustomerKey)
return enc, errors.GetAPIError(errors.ErrMissingSSECustomerKey)
}
if sseCustomerAlgorithm != layer.AESEncryptionAlgorithm {
return enc, apierr.GetAPIError(apierr.ErrInvalidEncryptionAlgorithm)
return enc, errors.GetAPIError(errors.ErrInvalidEncryptionAlgorithm)
}
key, err := base64.StdEncoding.DecodeString(sseCustomerKey)
if err != nil {
if isCopySource {
return enc, apierr.GetAPIError(apierr.ErrInvalidSSECustomerParameters)
return enc, errors.GetAPIError(errors.ErrInvalidSSECustomerParameters)
}
return enc, apierr.GetAPIError(apierr.ErrInvalidSSECustomerKey)
return enc, errors.GetAPIError(errors.ErrInvalidSSECustomerKey)
}
if len(key) != layer.AESKeySize {
if isCopySource {
return enc, apierr.GetAPIError(apierr.ErrInvalidSSECustomerParameters)
return enc, errors.GetAPIError(errors.ErrInvalidSSECustomerParameters)
}
return enc, apierr.GetAPIError(apierr.ErrInvalidSSECustomerKey)
return enc, errors.GetAPIError(errors.ErrInvalidSSECustomerKey)
}
keyMD5, err := base64.StdEncoding.DecodeString(sseCustomerKeyMD5)
if err != nil {
return enc, apierr.GetAPIError(apierr.ErrSSECustomerKeyMD5Mismatch)
return enc, errors.GetAPIError(errors.ErrSSECustomerKeyMD5Mismatch)
}
md5Sum := md5.Sum(key)
if !bytes.Equal(md5Sum[:], keyMD5) {
return enc, apierr.GetAPIError(apierr.ErrSSECustomerKeyMD5Mismatch)
return enc, errors.GetAPIError(errors.ErrSSECustomerKeyMD5Mismatch)
}
params, err := encryption.NewParams(key)
@ -435,7 +444,7 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
policy, err := checkPostPolicy(r, reqInfo, metadata)
if err != nil {
h.logAndSendError(ctx, w, "failed check policy", reqInfo, err)
h.logAndSendError(w, "failed check policy", reqInfo, err)
return
}
@ -443,31 +452,31 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
buffer := bytes.NewBufferString(tagging)
tags := new(data.Tagging)
if err = h.cfg.NewXMLDecoder(buffer).Decode(tags); err != nil {
h.logAndSendError(ctx, w, "could not decode tag set", reqInfo,
fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error()))
h.logAndSendError(w, "could not decode tag set", reqInfo,
fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrMalformedXML), err.Error()))
return
}
tagSet, err = h.readTagSet(tags)
if err != nil {
h.logAndSendError(ctx, w, "could not read tag set", reqInfo, err)
h.logAndSendError(w, "could not read tag set", reqInfo, err)
return
}
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket objInfo", reqInfo, err)
h.logAndSendError(w, "could not get bucket objInfo", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
if acl := auth.MultipartFormValue(r, "acl"); acl != "" && acl != basicACLPrivate {
h.logAndSendError(ctx, w, "acl not supported for this bucket", reqInfo, apierr.GetAPIError(apierr.ErrAccessControlListNotSupported))
h.logAndSendError(w, "acl not supported for this bucket", reqInfo, errors.GetAPIError(errors.ErrAccessControlListNotSupported))
return
}
@ -485,7 +494,7 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
if reqInfo.ObjectName == "" || strings.Contains(reqInfo.ObjectName, "${filename}") {
_, head, err := r.FormFile("file")
if err != nil {
h.logAndSendError(ctx, w, "could not parse file field", reqInfo, err)
h.logAndSendError(w, "could not parse file field", reqInfo, err)
return
}
filename = head.Filename
@ -494,7 +503,7 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
var head *multipart.FileHeader
contentReader, head, err = r.FormFile("file")
if err != nil {
h.logAndSendError(ctx, w, "could not parse file field", reqInfo, err)
h.logAndSendError(w, "could not parse file field", reqInfo, err)
return
}
size = uint64(head.Size)
@ -508,12 +517,12 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
}
if reqInfo.ObjectName == "" {
h.logAndSendError(ctx, w, "missing object name", reqInfo, apierr.GetAPIError(apierr.ErrInvalidArgument))
h.logAndSendError(w, "missing object name", reqInfo, errors.GetAPIError(errors.ErrInvalidArgument))
return
}
if !policy.CheckContentLength(size) {
h.logAndSendError(ctx, w, "invalid content-length", reqInfo, apierr.GetAPIError(apierr.ErrInvalidArgument))
h.logAndSendError(w, "invalid content-length", reqInfo, errors.GetAPIError(errors.ErrInvalidArgument))
return
}
@ -521,13 +530,13 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
BktInfo: bktInfo,
Object: reqInfo.ObjectName,
Reader: contentReader,
Size: &size,
Size: size,
Header: metadata,
}
extendedObjInfo, err := h.obj.PutObject(ctx, params)
if err != nil {
h.logAndSendError(ctx, w, "could not upload object", reqInfo, err)
h.logAndSendError(w, "could not upload object", reqInfo, err)
return
}
objInfo := extendedObjInfo.ObjectInfo
@ -543,7 +552,7 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
}
if err = h.obj.PutObjectTagging(ctx, tagPrm); err != nil {
h.logAndSendError(ctx, w, "could not upload object tagging", reqInfo, err)
h.logAndSendError(w, "could not upload object tagging", reqInfo, err)
return
}
}
@ -571,10 +580,10 @@ func (h *handler) PostObject(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status)
respData, err := middleware.EncodeResponse(resp)
if err != nil {
h.logAndSendError(ctx, w, "encode response", reqInfo, err)
h.logAndSendError(w, "encode response", reqInfo, err)
}
if _, err = w.Write(respData); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
return
}
@ -595,13 +604,13 @@ func checkPostPolicy(r *http.Request, reqInfo *middleware.ReqInfo, metadata map[
return nil, fmt.Errorf("could not unmarshal policy: %w", err)
}
if policy.Expiration.Before(time.Now()) {
return nil, fmt.Errorf("policy is expired: %w", apierr.GetAPIError(apierr.ErrInvalidArgument))
return nil, fmt.Errorf("policy is expired: %w", errors.GetAPIError(errors.ErrInvalidArgument))
}
policy.empty = false
}
if r.MultipartForm == nil {
return nil, errors.New("empty multipart form")
return nil, stderrors.New("empty multipart form")
}
for key, v := range r.MultipartForm.Value {
@ -632,7 +641,7 @@ func checkPostPolicy(r *http.Request, reqInfo *middleware.ReqInfo, metadata map[
for _, cond := range policy.Conditions {
if cond.Key == "bucket" {
if !cond.match(reqInfo.BucketName) {
return nil, apierr.GetAPIError(apierr.ErrPostPolicyConditionInvalidFormat)
return nil, errors.GetAPIError(errors.ErrPostPolicyConditionInvalidFormat)
}
}
}
@ -674,10 +683,10 @@ func parseTaggingHeader(header http.Header) (map[string]string, error) {
if tagging := header.Get(api.AmzTagging); len(tagging) > 0 {
queries, err := url.ParseQuery(tagging)
if err != nil {
return nil, apierr.GetAPIError(apierr.ErrInvalidArgument)
return nil, errors.GetAPIError(errors.ErrInvalidArgument)
}
if len(queries) > maxTags {
return nil, apierr.GetAPIError(apierr.ErrInvalidTagsSizeExceed)
return nil, errors.GetAPIError(errors.ErrInvalidTagsSizeExceed)
}
tagSet = make(map[string]string, len(queries))
for k, v := range queries {
@ -727,7 +736,7 @@ func (h *handler) parseCommonCreateBucketParams(reqInfo *middleware.ReqInfo, box
}
if p.SessionContainerCreation == nil {
return nil, nil, fmt.Errorf("%w: couldn't find session token for put", apierr.GetAPIError(apierr.ErrAccessDenied))
return nil, nil, fmt.Errorf("%w: couldn't find session token for put", errors.GetAPIError(errors.ErrAccessDenied))
}
if err := checkBucketName(reqInfo.BucketName); err != nil {
@ -759,33 +768,32 @@ func (h *handler) createBucketHandlerPolicy(w http.ResponseWriter, r *http.Reque
boxData, err := middleware.GetBoxData(ctx)
if err != nil {
h.logAndSendError(ctx, w, "get access box from request", reqInfo, err)
h.logAndSendError(w, "get access box from request", reqInfo, err)
return
}
key, p, err := h.parseCommonCreateBucketParams(reqInfo, boxData, r)
if err != nil {
h.logAndSendError(ctx, w, "parse create bucket params", reqInfo, err)
h.logAndSendError(w, "parse create bucket params", reqInfo, err)
return
}
cannedACL, err := parseCannedACL(r.Header)
if err != nil {
h.logAndSendError(ctx, w, "could not parse canned ACL", reqInfo, err)
h.logAndSendError(w, "could not parse canned ACL", reqInfo, err)
return
}
bktInfo, err := h.obj.CreateBucket(ctx, p)
if err != nil {
h.logAndSendError(ctx, w, "could not create bucket", reqInfo, err)
h.logAndSendError(w, "could not create bucket", reqInfo, err)
return
}
h.reqLogger(ctx).Info(logs.BucketIsCreated, zap.Stringer("container_id", bktInfo.CID))
chains := bucketCannedACLToAPERules(cannedACL, reqInfo, bktInfo.CID)
if err = h.ape.SaveACLChains(bktInfo.CID.EncodeToString(), chains); err != nil {
cleanErr := h.cleanupBucketCreation(ctx, reqInfo, bktInfo, boxData, chains)
h.logAndSendError(ctx, w, "failed to add morph rule chain", reqInfo, err, zap.NamedError("cleanup_error", cleanErr))
h.logAndSendError(w, "failed to add morph rule chain", reqInfo, err)
return
}
@ -806,40 +814,17 @@ func (h *handler) createBucketHandlerPolicy(w http.ResponseWriter, r *http.Reque
return h.obj.PutBucketSettings(ctx, sp)
}, h.putBucketSettingsRetryer())
if err != nil {
cleanErr := h.cleanupBucketCreation(ctx, reqInfo, bktInfo, boxData, chains)
h.logAndSendError(ctx, w, "couldn't save bucket settings", reqInfo, err,
zap.String("container_id", bktInfo.CID.EncodeToString()), zap.NamedError("cleanup_error", cleanErr))
h.logAndSendError(w, "couldn't save bucket settings", reqInfo, err,
zap.String("container_id", bktInfo.CID.EncodeToString()))
return
}
if err = middleware.WriteSuccessResponseHeadersOnly(w); err != nil {
h.logAndSendError(ctx, w, "write response", reqInfo, err)
h.logAndSendError(w, "write response", reqInfo, err)
return
}
}
func (h *handler) cleanupBucketCreation(ctx context.Context, reqInfo *middleware.ReqInfo, bktInfo *data.BucketInfo, boxData *accessbox.Box, chains []*chain.Chain) error {
prm := &layer.DeleteBucketParams{
BktInfo: bktInfo,
SessionToken: boxData.Gate.SessionTokenForDelete(),
}
if err := h.obj.DeleteContainer(ctx, prm); err != nil {
return err
}
chainIDs := make([]chain.ID, len(chains))
for i, c := range chains {
chainIDs[i] = c.ID
}
if err := h.ape.DeleteBucketPolicy(reqInfo.Namespace, bktInfo.CID, chainIDs); err != nil {
return fmt.Errorf("delete bucket acl policy: %w", err)
}
return nil
}
func (h *handler) putBucketSettingsRetryer() aws.RetryerV2 {
return retry.NewStandard(func(options *retry.StandardOptions) {
options.MaxAttempts = h.cfg.RetryMaxAttempts()
@ -853,7 +838,7 @@ func (h *handler) putBucketSettingsRetryer() aws.RetryerV2 {
}
options.Retryables = []retry.IsErrorRetryable{retry.IsErrorRetryableFunc(func(err error) aws.Ternary {
if errors.Is(err, tree.ErrNodeAccessDenied) {
if stderrors.Is(err, tree.ErrNodeAccessDenied) {
return aws.TrueTernary
}
return aws.FalseTernary
@ -982,7 +967,7 @@ func (h handler) setPlacementPolicy(prm *layer.CreateBucketParams, namespace, lo
return nil
}
return apierr.GetAPIError(apierr.ErrInvalidLocationConstraint)
return errors.GetAPIError(errors.ErrInvalidLocationConstraint)
}
func isLockEnabled(log *zap.Logger, header http.Header) bool {
@ -1001,22 +986,28 @@ func isLockEnabled(log *zap.Logger, header http.Header) bool {
func checkBucketName(bucketName string) error {
if len(bucketName) < 3 || len(bucketName) > 63 {
return apierr.GetAPIError(apierr.ErrInvalidBucketName)
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
if strings.HasPrefix(bucketName, "xn--") || strings.HasSuffix(bucketName, "-s3alias") {
return apierr.GetAPIError(apierr.ErrInvalidBucketName)
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
if net.ParseIP(bucketName) != nil {
return apierr.GetAPIError(apierr.ErrInvalidBucketName)
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
for i, r := range bucketName {
if r == '.' || (!isAlphaNum(r) && r != '-') {
return apierr.GetAPIError(apierr.ErrInvalidBucketName)
labels := strings.Split(bucketName, ".")
for _, label := range labels {
if len(label) == 0 {
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
for i, r := range label {
if !isAlphaNum(r) && r != '-' {
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
if (i == 0 || i == len(label)-1) && r == '-' {
return errors.GetAPIError(errors.ErrInvalidBucketName)
}
if (i == 0 || i == len(bucketName)-1) && r == '-' {
return apierr.GetAPIError(apierr.ErrInvalidBucketName)
}
}
@ -1034,7 +1025,7 @@ func (h *handler) parseLocationConstraint(r *http.Request) (*createBucketParams,
params := new(createBucketParams)
if err := h.cfg.NewXMLDecoder(r.Body).Decode(params); err != nil {
return nil, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrMalformedXML), err.Error())
return nil, fmt.Errorf("%w: %s", errors.GetAPIError(errors.ErrMalformedXML), err.Error())
}
return params, nil
}

View file

@ -7,7 +7,6 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
@ -19,22 +18,17 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/v4"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/stretchr/testify/require"
)
const (
awsChunkedRequestExampleDecodedContentLength = 66560
awsChunkedRequestExampleContentLength = 66824
)
func TestCheckBucketName(t *testing.T) {
for _, tc := range []struct {
name string
@ -42,10 +36,10 @@ func TestCheckBucketName(t *testing.T) {
}{
{name: "bucket"},
{name: "2bucket"},
{name: "buc.ket"},
{name: "buc-ket"},
{name: "abc"},
{name: "63aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"},
{name: "buc.ket", err: true},
{name: "buc.-ket", err: true},
{name: "bucket.", err: true},
{name: ".bucket", err: true},
@ -205,7 +199,7 @@ func TestPostObject(t *testing.T) {
t.Run(tc.key+";"+tc.filename, func(t *testing.T) {
w := postObjectBase(hc, ns, bktName, tc.key, tc.filename, tc.content)
if tc.err {
assertS3Error(hc.t, w, apierr.GetAPIError(apierr.ErrInternalError))
assertS3Error(hc.t, w, s3errors.GetAPIError(s3errors.ErrInternalError))
return
}
assertStatus(hc.t, w, http.StatusNoContent)
@ -251,10 +245,6 @@ func TestPutObjectWithNegativeContentLength(t *testing.T) {
tc.Handler().HeadObjectHandler(w, r)
assertStatus(t, w, http.StatusOK)
require.Equal(t, strconv.Itoa(len(content)), w.Header().Get(api.ContentLength))
result := listVersions(t, tc, bktName)
require.Len(t, result.Version, 1)
require.EqualValues(t, len(content), result.Version[0].Size)
}
func TestPutObjectWithStreamBodyError(t *testing.T) {
@ -268,7 +258,7 @@ func TestPutObjectWithStreamBodyError(t *testing.T) {
r.Header.Set(api.AmzContentSha256, api.StreamingContentSHA256)
r.Header.Set(api.ContentEncoding, api.AwsChunked)
tc.Handler().PutObjectHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrMissingContentLength))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrMissingContentLength))
checkNotFound(t, tc, bktName, objName, emptyVersion)
}
@ -284,7 +274,7 @@ func TestPutObjectWithInvalidContentMD5(t *testing.T) {
w, r := prepareTestPayloadRequest(tc, bktName, objName, bytes.NewReader(content))
r.Header.Set(api.ContentMD5, base64.StdEncoding.EncodeToString([]byte("invalid")))
tc.Handler().PutObjectHandler(w, r)
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrInvalidDigest))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrInvalidDigest))
checkNotFound(t, tc, bktName, objName, emptyVersion)
}
@ -347,7 +337,7 @@ func TestPutObjectCheckContentSHA256(t *testing.T) {
hc.Handler().PutObjectHandler(w, r)
if tc.error {
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrContentSHA256Mismatch))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrContentSHA256Mismatch))
w, r := prepareTestRequest(hc, bktName, objName, nil)
hc.Handler().GetObjectHandler(w, r)
@ -371,37 +361,12 @@ func TestPutObjectWithStreamBodyAWSExample(t *testing.T) {
hc.Handler().PutObjectHandler(w, req)
assertStatus(t, w, http.StatusOK)
w, req = prepareTestRequest(hc, bktName, objName, nil)
hc.Handler().HeadObjectHandler(w, req)
assertStatus(t, w, http.StatusOK)
require.Equal(t, strconv.Itoa(awsChunkedRequestExampleDecodedContentLength), w.Header().Get(api.ContentLength))
data := getObjectRange(t, hc, bktName, objName, 0, awsChunkedRequestExampleDecodedContentLength)
data := getObjectRange(t, hc, bktName, objName, 0, 66824)
for i := range chunk {
require.Equal(t, chunk[i], data[i])
}
}
func TestPutObjectWithStreamEmptyBodyAWSExample(t *testing.T) {
hc := prepareHandlerContext(t)
bktName, objName := "dkirillov", "tmp"
createTestBucket(hc, bktName)
w, req := getEmptyChunkedRequest(hc.context, t, bktName, objName)
hc.Handler().PutObjectHandler(w, req)
assertStatus(t, w, http.StatusOK)
w, req = prepareTestRequest(hc, bktName, objName, nil)
hc.Handler().HeadObjectHandler(w, req)
assertStatus(t, w, http.StatusOK)
require.Equal(t, "0", w.Header().Get(api.ContentLength))
res := listObjectsV1(hc, bktName, "", "", "", -1)
require.Len(t, res.Contents, 1)
require.Empty(t, res.Contents[0].Size)
}
func TestPutChunkedTestContentEncoding(t *testing.T) {
hc := prepareHandlerContext(t)
@ -420,7 +385,7 @@ func TestPutChunkedTestContentEncoding(t *testing.T) {
w, req, _ = getChunkedRequest(hc.context, t, bktName, objName)
req.Header.Set(api.ContentEncoding, "gzip")
hc.Handler().PutObjectHandler(w, req)
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrInvalidEncodingMethod))
assertS3Error(t, w, s3errors.GetAPIError(s3errors.ErrInvalidEncodingMethod))
hc.config.bypassContentEncodingInChunks = true
w, req, _ = getChunkedRequest(hc.context, t, bktName, objName)
@ -432,8 +397,6 @@ func TestPutChunkedTestContentEncoding(t *testing.T) {
require.Equal(t, "gzip", resp.Header().Get(api.ContentEncoding))
}
// getChunkedRequest implements request example from
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
func getChunkedRequest(ctx context.Context, t *testing.T, bktName, objName string) (*httptest.ResponseRecorder, *http.Request, []byte) {
chunk := make([]byte, 65*1024)
for i := range chunk {
@ -445,8 +408,8 @@ func getChunkedRequest(ctx context.Context, t *testing.T, bktName, objName strin
AWSAccessKeyID := "AKIAIOSFODNN7EXAMPLE"
AWSSecretAccessKey := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
awsCreds := credentials.NewStaticCredentials(AWSAccessKeyID, AWSSecretAccessKey, "")
signer := v4.NewSigner(awsCreds)
awsCreds := aws.Credentials{AccessKeyID: AWSAccessKeyID, SecretAccessKey: AWSSecretAccessKey}
signer := v4.NewSigner()
reqBody := bytes.NewBufferString("10000;chunk-signature=ad80c730a21e5b8d04586a2213dd63b9a0e99e0e2307b0ade35a65485a288648\r\n")
_, err := reqBody.Write(chunk1)
@ -461,15 +424,15 @@ func getChunkedRequest(ctx context.Context, t *testing.T, bktName, objName strin
req, err := http.NewRequest("PUT", "https://s3.amazonaws.com/"+bktName+"/"+objName, nil)
require.NoError(t, err)
req.Header.Set("content-encoding", "aws-chunked")
req.Header.Set("content-length", strconv.Itoa(awsChunkedRequestExampleContentLength))
req.Header.Set("content-length", "66824")
req.Header.Set("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
req.Header.Set("x-amz-decoded-content-length", strconv.Itoa(awsChunkedRequestExampleDecodedContentLength))
req.Header.Set("x-amz-decoded-content-length", "66560")
req.Header.Set("x-amz-storage-class", "REDUCED_REDUNDANCY")
signTime, err := time.Parse("20060102T150405Z", "20130524T000000Z")
require.NoError(t, err)
_, err = signer.Sign(req, nil, "s3", "us-east-1", signTime)
err = signer.SignHTTP(ctx, awsCreds, req, auth.UnsignedPayload, "s3", "us-east-1", signTime)
require.NoError(t, err)
req.Body = io.NopCloser(reqBody)
@ -494,72 +457,15 @@ func getChunkedRequest(ctx context.Context, t *testing.T, bktName, objName strin
return w, req, chunk
}
func getEmptyChunkedRequest(ctx context.Context, t *testing.T, bktName, objName string) (*httptest.ResponseRecorder, *http.Request) {
AWSAccessKeyID := "48c1K4PLVb7SvmV3PjDKEuXaMh8yZMXZ8Wx9msrkKcYw06dZeaxeiPe8vyFm2WsoeVaNt7UWEjNsVkagDs8oX4XXh"
AWSSecretAccessKey := "09260955b4eb0279dc017ba20a1ddac909cbd226c86cbb2d868e55534c8e64b0"
//awsCreds := credentials.NewStaticCredentials(AWSAccessKeyID, AWSSecretAccessKey, "")
//signer := v4.NewSigner(awsCreds)
reqBody := bytes.NewBufferString("0;chunk-signature=311a7142c8f3a07972c3aca65c36484b513a8fee48ab7178c7225388f2ae9894\r\n\r\n")
req, err := http.NewRequest("PUT", "http://localhost:8084/"+bktName+"/"+objName, reqBody)
require.NoError(t, err)
req.Header.Set("Amz-Sdk-Invocation-Id", "8a8cd4be-aef8-8034-f08d-a6144ade41f9")
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=2")
req.Header.Set(api.Authorization, "AWS4-HMAC-SHA256 Credential=48c1K4PLVb7SvmV3PjDKEuXaMh8yZMXZ8Wx9msrkKcYw06dZeaxeiPe8vyFm2WsoeVaNt7UWEjNsVkagDs8oX4XXh/20241003/us-east-1/s3/aws4_request, SignedHeaders=amz-sdk-invocation-id;amz-sdk-request;content-encoding;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-decoded-content-length, Signature=4b530ab4af2381f214941af591266b209968264a2c94337fa1efc048c7dff352")
req.Header.Set(api.ContentEncoding, "aws-chunked")
req.Header.Set(api.ContentLength, "86")
req.Header.Set(api.ContentType, "text/plain; charset=UTF-8")
req.Header.Set(api.AmzDate, "20241003T100055Z")
req.Header.Set(api.AmzContentSha256, "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
req.Header.Set(api.AmzDecodedContentLength, "0")
signTime, err := time.Parse("20060102T150405Z", "20241003T100055Z")
require.NoError(t, err)
w := httptest.NewRecorder()
reqInfo := middleware.NewReqInfo(w, req, middleware.ObjectRequest{Bucket: bktName, Object: objName}, "")
req = req.WithContext(middleware.SetReqInfo(ctx, reqInfo))
req = req.WithContext(middleware.SetBox(req.Context(), &middleware.Box{
ClientTime: signTime,
AuthHeaders: &middleware.AuthHeader{
AccessKeyID: AWSAccessKeyID,
SignatureV4: "4b530ab4af2381f214941af591266b209968264a2c94337fa1efc048c7dff352",
Region: "us-east-1",
},
AccessBox: &accessbox.Box{
Gate: &accessbox.GateData{
SecretKey: AWSSecretAccessKey,
},
},
}))
return w, req
}
func TestCreateBucket(t *testing.T) {
hc := prepareHandlerContext(t)
bktName := "bkt-name"
info := createBucket(hc, bktName)
createBucketAssertS3Error(hc, bktName, info.Box, apierr.ErrBucketAlreadyOwnedByYou)
createBucketAssertS3Error(hc, bktName, info.Box, s3errors.ErrBucketAlreadyOwnedByYou)
box2, _ := createAccessBox(t)
createBucketAssertS3Error(hc, bktName, box2, apierr.ErrBucketAlreadyExists)
}
func TestCreateBucketWithoutPermissions(t *testing.T) {
hc := prepareHandlerContext(t)
bktName := "bkt-name"
hc.h.ape.(*apeMock).err = errors.New("no permissions")
box, _ := createAccessBox(t)
createBucketAssertS3Error(hc, bktName, box, apierr.ErrInternalError)
_, err := hc.tp.ContainerID(bktName)
require.Errorf(t, err, "container exists after failed creation, but shouldn't")
createBucketAssertS3Error(hc, bktName, box2, s3errors.ErrBucketAlreadyExists)
}
func TestCreateNamespacedBucket(t *testing.T) {

View file

@ -3,16 +3,17 @@ package handler
import (
"bufio"
"bytes"
"context"
"encoding/hex"
"errors"
"io"
"net/http"
"time"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4"
v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4sdk2/signer/v4"
errs "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
)
const (
@ -22,6 +23,7 @@ const (
type (
s3ChunkReader struct {
ctx context.Context
reader *bufio.Reader
streamSigner *v4.StreamSigner
@ -166,7 +168,7 @@ func (c *s3ChunkReader) Read(buf []byte) (num int, err error) {
// Once we have read the entire chunk successfully, we verify
// that the received signature matches our computed signature.
calculatedSignature, err := c.streamSigner.GetSignature(nil, c.buffer, c.requestTime)
calculatedSignature, err := c.streamSigner.GetSignature(c.ctx, nil, c.buffer, c.requestTime)
if err != nil {
c.err = err
return num, c.err
@ -189,29 +191,31 @@ func (c *s3ChunkReader) Read(buf []byte) (num int, err error) {
}
func newSignV4ChunkedReader(req *http.Request) (io.ReadCloser, error) {
box, err := middleware.GetBoxData(req.Context())
ctx := req.Context()
box, err := middleware.GetBoxData(ctx)
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
authHeaders, err := middleware.GetAuthHeaders(req.Context())
authHeaders, err := middleware.GetAuthHeaders(ctx)
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
currentCredentials := credentials.NewStaticCredentials(authHeaders.AccessKeyID, box.Gate.SecretKey, "")
currentCredentials := aws.Credentials{AccessKeyID: authHeaders.AccessKeyID, SecretAccessKey: box.Gate.SecretKey}
seed, err := hex.DecodeString(authHeaders.SignatureV4)
if err != nil {
return nil, errs.GetAPIError(errs.ErrSignatureDoesNotMatch)
}
reqTime, err := middleware.GetClientTime(req.Context())
reqTime, err := middleware.GetClientTime(ctx)
if err != nil {
return nil, errs.GetAPIError(errs.ErrMalformedDate)
}
newStreamSigner := v4.NewStreamSigner(authHeaders.Region, "s3", seed, currentCredentials)
newStreamSigner := v4.NewStreamSigner(currentCredentials, "s3", authHeaders.Region, seed)
return &s3ChunkReader{
ctx: ctx,
reader: bufio.NewReader(req.Body),
streamSigner: newStreamSigner,
requestTime: reqTime,

209
api/handler/s3v4aReader.go Normal file
View file

@ -0,0 +1,209 @@
package handler
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"net/http"
"time"
v4a "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4asdk2"
errs "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
)
type (
s3v4aChunkReader struct {
reader *bufio.Reader
streamSigner *v4a.StreamSigner
requestTime time.Time
buffer []byte
offset int
err error
}
)
func (c *s3v4aChunkReader) Close() (err error) {
return nil
}
func (c *s3v4aChunkReader) Read(buf []byte) (num int, err error) {
if c.offset > 0 {
num = copy(buf, c.buffer[c.offset:])
if num == len(buf) {
c.offset += num
return num, nil
}
c.offset = 0
buf = buf[num:]
}
var size int
for {
b, err := c.reader.ReadByte()
if err != nil {
return c.handleErr(num, err)
}
if b == ';' { // separating character
break
}
// Manually deserialize the size since AWS specified
// the chunk size to be of variable width. In particular,
// a size of 16 is encoded as `10` while a size of 64 KB
// is `10000`.
switch {
case b >= '0' && b <= '9':
size = size<<4 | int(b-'0')
case b >= 'a' && b <= 'f':
size = size<<4 | int(b-('a'-10))
case b >= 'A' && b <= 'F':
size = size<<4 | int(b-('A'-10))
default:
c.err = errMalformedChunkedEncoding
return num, c.err
}
if size > maxChunkSize {
c.err = errGiantChunk
return num, c.err
}
}
// Now, we read the signature of the following payload and expect:
// chunk-signature=" + <signature-as-hex> + "**\r\n"
//
// The signature is 64 bytes long (hex-encoded SHA256 hash) and
// starts with a 16 byte header: len("chunk-signature=") + 144 == 160.
var signature [160]byte
_, err = io.ReadFull(c.reader, signature[:])
if err != nil {
return c.handleErr(num, err)
}
if !bytes.HasPrefix(signature[:], []byte(chunkSignatureHeader)) {
c.err = errMalformedChunkedEncoding
return num, c.err
}
b, err := c.reader.ReadByte()
if err != nil {
return c.handleErr(num, err)
}
if b != '\r' {
c.err = errMalformedChunkedEncoding
return num, c.err
}
b, err = c.reader.ReadByte()
if err != nil {
return c.handleErr(num, err)
}
if b != '\n' {
c.err = errMalformedChunkedEncoding
return num, c.err
}
if cap(c.buffer) < size {
c.buffer = make([]byte, size)
} else {
c.buffer = c.buffer[:size]
}
// Now, we read the payload and compute its SHA-256 hash.
_, err = io.ReadFull(c.reader, c.buffer)
if err == io.EOF && size != 0 {
err = io.ErrUnexpectedEOF
}
if err != nil && err != io.EOF {
c.err = err
return num, c.err
}
b, err = c.reader.ReadByte()
if b != '\r' || err != nil {
c.err = errMalformedChunkedEncoding
return num, c.err
}
b, err = c.reader.ReadByte()
if err != nil {
return c.handleErr(num, err)
}
if b != '\n' {
c.err = errMalformedChunkedEncoding
return num, c.err
}
// Once we have read the entire chunk successfully, we verify
// that the received signature is valid.
n, err := hex.Decode(signature[:], bytes.TrimSuffix(signature[:], []byte("**"))[16:])
if err != nil {
c.err = errMalformedChunkedEncoding
return num, c.err
}
if err = c.streamSigner.VerifySignature(nil, c.buffer, c.requestTime, signature[:n]); err != nil {
c.err = err
return num, c.err
}
// If the chunk size is zero we return io.EOF. As specified by AWS,
// only the last chunk is zero-sized.
if size == 0 {
c.err = io.EOF
return num, c.err
}
c.offset = copy(buf, c.buffer)
num += c.offset
return num, err
}
func (c *s3v4aChunkReader) handleErr(num int, err error) (int, error) {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
c.err = err
return num, c.err
}
func newSignV4aChunkedReader(req *http.Request) (io.ReadCloser, error) {
box, err := middleware.GetBoxData(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
authHeaders, err := middleware.GetAuthHeaders(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrAuthorizationHeaderMalformed)
}
seed, err := hex.DecodeString(authHeaders.SignatureV4)
if err != nil {
return nil, errs.GetAPIError(errs.ErrSignatureDoesNotMatch)
}
reqTime, err := middleware.GetClientTime(req.Context())
if err != nil {
return nil, errs.GetAPIError(errs.ErrMalformedDate)
}
credAdapter := v4a.SymmetricCredentialAdaptor{
SymmetricProvider: credentialsv2.NewStaticCredentialsProvider(authHeaders.AccessKeyID, box.Gate.SecretKey, ""),
}
creds, err := credAdapter.RetrievePrivateKey(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to derive assymetric key from credentials: %w", err)
}
newStreamSigner := v4a.NewStreamSigner(creds, "s3", seed)
return &s3v4aChunkReader{
reader: bufio.NewReader(req.Body),
streamSigner: newStreamSigner,
requestTime: reqTime,
buffer: make([]byte, 64*1024),
}, nil
}

View file

@ -26,13 +26,13 @@ func (h *handler) PutObjectTaggingHandler(w http.ResponseWriter, r *http.Request
tagSet, err := h.readTagSet(reqInfo.Tagging)
if err != nil {
h.logAndSendError(ctx, w, "could not read tag set", reqInfo, err)
h.logAndSendError(w, "could not read tag set", reqInfo, err)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -46,7 +46,7 @@ func (h *handler) PutObjectTaggingHandler(w http.ResponseWriter, r *http.Request
}
if err = h.obj.PutObjectTagging(ctx, tagPrm); err != nil {
h.logAndSendError(ctx, w, "could not put object tagging", reqInfo, err)
h.logAndSendError(w, "could not put object tagging", reqInfo, err)
return
}
@ -54,18 +54,17 @@ func (h *handler) PutObjectTaggingHandler(w http.ResponseWriter, r *http.Request
}
func (h *handler) GetObjectTaggingHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket settings", reqInfo, err)
h.logAndSendError(w, "could not get bucket settings", reqInfo, err)
return
}
@ -77,9 +76,9 @@ func (h *handler) GetObjectTaggingHandler(w http.ResponseWriter, r *http.Request
},
}
versionID, tagSet, err := h.obj.GetObjectTagging(ctx, tagPrm)
versionID, tagSet, err := h.obj.GetObjectTagging(r.Context(), tagPrm)
if err != nil {
h.logAndSendError(ctx, w, "could not get object tagging", reqInfo, err)
h.logAndSendError(w, "could not get object tagging", reqInfo, err)
return
}
@ -87,7 +86,7 @@ func (h *handler) GetObjectTaggingHandler(w http.ResponseWriter, r *http.Request
w.Header().Set(api.AmzVersionID, versionID)
}
if err = middleware.EncodeToResponse(w, encodeTagging(tagSet)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}
@ -97,7 +96,7 @@ func (h *handler) DeleteObjectTaggingHandler(w http.ResponseWriter, r *http.Requ
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
@ -108,7 +107,7 @@ func (h *handler) DeleteObjectTaggingHandler(w http.ResponseWriter, r *http.Requ
}
if err = h.obj.DeleteObjectTagging(ctx, p); err != nil {
h.logAndSendError(ctx, w, "could not delete object tagging", reqInfo, err)
h.logAndSendError(w, "could not delete object tagging", reqInfo, err)
return
}
@ -116,61 +115,58 @@ func (h *handler) DeleteObjectTaggingHandler(w http.ResponseWriter, r *http.Requ
}
func (h *handler) PutBucketTaggingHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
tagSet, err := h.readTagSet(reqInfo.Tagging)
if err != nil {
h.logAndSendError(ctx, w, "could not read tag set", reqInfo, err)
h.logAndSendError(w, "could not read tag set", reqInfo, err)
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if err = h.obj.PutBucketTagging(ctx, bktInfo, tagSet); err != nil {
h.logAndSendError(ctx, w, "could not put object tagging", reqInfo, err)
if err = h.obj.PutBucketTagging(r.Context(), bktInfo, tagSet); err != nil {
h.logAndSendError(w, "could not put object tagging", reqInfo, err)
return
}
}
func (h *handler) GetBucketTaggingHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
tagSet, err := h.obj.GetBucketTagging(ctx, bktInfo)
tagSet, err := h.obj.GetBucketTagging(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "could not get object tagging", reqInfo, err)
h.logAndSendError(w, "could not get object tagging", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, encodeTagging(tagSet)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
return
}
}
func (h *handler) DeleteBucketTaggingHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
if err = h.obj.DeleteBucketTagging(ctx, bktInfo); err != nil {
h.logAndSendError(ctx, w, "could not delete bucket tagging", reqInfo, err)
if err = h.obj.DeleteBucketTagging(r.Context(), bktInfo); err != nil {
h.logAndSendError(w, "could not delete bucket tagging", reqInfo, err)
return
}
w.WriteHeader(http.StatusNoContent)

View file

@ -6,7 +6,7 @@ import (
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
apiErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"github.com/stretchr/testify/require"
)
@ -98,7 +98,7 @@ func TestPutObjectTaggingCheckUniqueness(t *testing.T) {
middleware.GetReqInfo(r.Context()).Tagging = tc.body
hc.Handler().PutObjectTaggingHandler(w, r)
if tc.error {
assertS3Error(t, w, apierr.GetAPIError(apierr.ErrInvalidTagKeyUniqueness))
assertS3Error(t, w, apiErrors.GetAPIError(apiErrors.ErrInvalidTagKeyUniqueness))
return
}
assertStatus(t, w, http.StatusOK)

View file

@ -8,53 +8,53 @@ import (
)
func (h *handler) SelectObjectContentHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketEncryptionHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketWebsiteHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketAccelerateHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketRequestPaymentHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketLoggingHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketReplicationHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) DeleteBucketWebsiteHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) ListenBucketNotificationHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) ListObjectsV2MHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) PutBucketEncryptionHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) PutBucketNotificationHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}
func (h *handler) GetBucketNotificationHandler(w http.ResponseWriter, r *http.Request) {
h.logAndSendError(r.Context(), w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
h.logAndSendError(w, "not implemented", middleware.GetReqInfo(r.Context()), errors.GetAPIError(errors.ErrNotImplemented))
}

View file

@ -10,11 +10,13 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
frosterrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
@ -26,14 +28,15 @@ func (h *handler) reqLogger(ctx context.Context) *zap.Logger {
return h.log
}
func (h *handler) logAndSendError(ctx context.Context, w http.ResponseWriter, logText string, reqInfo *middleware.ReqInfo, err error, additional ...zap.Field) {
func (h *handler) logAndSendError(w http.ResponseWriter, logText string, reqInfo *middleware.ReqInfo, err error, additional ...zap.Field) {
err = handleDeleteMarker(w, err)
if code, wrErr := middleware.WriteErrorResponse(w, reqInfo, apierr.TransformToS3Error(err)); wrErr != nil {
if code, wrErr := middleware.WriteErrorResponse(w, reqInfo, transformToS3Error(err)); wrErr != nil {
additional = append(additional, zap.NamedError("write_response_error", wrErr))
} else {
additional = append(additional, zap.Int("status", code))
}
fields := []zap.Field{
zap.String("request_id", reqInfo.RequestID),
zap.String("method", reqInfo.API),
zap.String("bucket", reqInfo.BucketName),
zap.String("object", reqInfo.ObjectName),
@ -41,7 +44,10 @@ func (h *handler) logAndSendError(ctx context.Context, w http.ResponseWriter, lo
zap.String("user", reqInfo.User),
zap.Error(err)}
fields = append(fields, additional...)
h.reqLogger(ctx).Error(logs.RequestFailed, fields...)
if traceID, err := trace.TraceIDFromHex(reqInfo.TraceID); err == nil && traceID.IsValid() {
fields = append(fields, zap.String("trace_id", reqInfo.TraceID))
}
h.log.Error(logs.RequestFailed, fields...) // consider using h.reqLogger (it requires accept context.Context or http.Request)
}
func handleDeleteMarker(w http.ResponseWriter, err error) error {
@ -51,7 +57,25 @@ func handleDeleteMarker(w http.ResponseWriter, err error) error {
}
w.Header().Set(api.AmzDeleteMarker, "true")
return fmt.Errorf("%w: %s", apierr.GetAPIError(target.ErrorCode), err)
return fmt.Errorf("%w: %s", s3errors.GetAPIError(target.ErrorCode), err)
}
func transformToS3Error(err error) error {
err = frosterrors.UnwrapErr(err) // this wouldn't work with errors.Join
if _, ok := err.(s3errors.Error); ok {
return err
}
if errors.Is(err, layer.ErrAccessDenied) ||
errors.Is(err, layer.ErrNodeAccessDenied) {
return s3errors.GetAPIError(s3errors.ErrAccessDenied)
}
if errors.Is(err, layer.ErrGatewayTimeout) {
return s3errors.GetAPIError(s3errors.ErrGatewayTimeout)
}
return s3errors.GetAPIError(s3errors.ErrInternalError)
}
func (h *handler) ResolveBucket(ctx context.Context, bucket string) (*data.BucketInfo, error) {
@ -82,19 +106,6 @@ func (h *handler) getBucketAndCheckOwner(r *http.Request, bucket string, header
return bktInfo, checkOwner(bktInfo, expected)
}
func (h *handler) getPutPayloadSize(r *http.Request) uint64 {
decodeContentSize := r.Header.Get(api.AmzDecodedContentLength)
decodedSize, err := strconv.Atoi(decodeContentSize)
if err != nil {
if r.ContentLength >= 0 {
return uint64(r.ContentLength)
}
return 0
}
return uint64(decodedSize)
}
func parseRange(s string) (*layer.RangeParams, error) {
if s == "" {
return nil, nil
@ -103,26 +114,26 @@ func parseRange(s string) (*layer.RangeParams, error) {
prefix := "bytes="
if !strings.HasPrefix(s, prefix) {
return nil, apierr.GetAPIError(apierr.ErrInvalidRange)
return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange)
}
s = strings.TrimPrefix(s, prefix)
valuesStr := strings.Split(s, "-")
if len(valuesStr) != 2 {
return nil, apierr.GetAPIError(apierr.ErrInvalidRange)
return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange)
}
values := make([]uint64, 0, len(valuesStr))
for _, v := range valuesStr {
num, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return nil, apierr.GetAPIError(apierr.ErrInvalidRange)
return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange)
}
values = append(values, num)
}
if values[0] > values[1] {
return nil, apierr.GetAPIError(apierr.ErrInvalidRange)
return nil, s3errors.GetAPIError(s3errors.ErrInvalidRange)
}
return &layer.RangeParams{

View file

@ -1 +1,64 @@
package handler
import (
"errors"
"fmt"
"testing"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"github.com/stretchr/testify/require"
)
func TestTransformS3Errors(t *testing.T) {
for _, tc := range []struct {
name string
err error
expected s3errors.ErrorCode
}{
{
name: "simple std error to internal error",
err: errors.New("some error"),
expected: s3errors.ErrInternalError,
},
{
name: "layer access denied error to s3 access denied error",
err: layer.ErrAccessDenied,
expected: s3errors.ErrAccessDenied,
},
{
name: "wrapped layer access denied error to s3 access denied error",
err: fmt.Errorf("wrap: %w", layer.ErrAccessDenied),
expected: s3errors.ErrAccessDenied,
},
{
name: "layer node access denied error to s3 access denied error",
err: layer.ErrNodeAccessDenied,
expected: s3errors.ErrAccessDenied,
},
{
name: "layer gateway timeout error to s3 gateway timeout error",
err: layer.ErrGatewayTimeout,
expected: s3errors.ErrGatewayTimeout,
},
{
name: "s3 error to s3 error",
err: s3errors.GetAPIError(s3errors.ErrInvalidPart),
expected: s3errors.ErrInvalidPart,
},
{
name: "wrapped s3 error to s3 error",
err: fmt.Errorf("wrap: %w", s3errors.GetAPIError(s3errors.ErrInvalidPart)),
expected: s3errors.ErrInvalidPart,
},
} {
t.Run(tc.name, func(t *testing.T) {
err := transformToS3Error(tc.err)
s3err, ok := err.(s3errors.Error)
require.True(t, ok, "error must be s3 error")
require.Equalf(t, tc.expected, s3err.ErrCode,
"expected: '%s', got: '%s'",
s3errors.GetAPIError(tc.expected).Code, s3errors.GetAPIError(s3err.ErrCode).Code)
})
}
}

View file

@ -10,29 +10,28 @@ import (
)
func (h *handler) PutBucketVersioningHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
configuration := new(VersioningConfiguration)
if err := h.cfg.NewXMLDecoder(r.Body).Decode(configuration); err != nil {
h.logAndSendError(ctx, w, "couldn't decode versioning configuration", reqInfo, errors.GetAPIError(errors.ErrIllegalVersioningConfigurationException))
h.logAndSendError(w, "couldn't decode versioning configuration", reqInfo, errors.GetAPIError(errors.ErrIllegalVersioningConfigurationException))
return
}
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get bucket settings", reqInfo, err)
h.logAndSendError(w, "couldn't get bucket settings", reqInfo, err)
return
}
if configuration.Status != data.VersioningEnabled && configuration.Status != data.VersioningSuspended {
h.logAndSendError(ctx, w, "invalid versioning configuration", reqInfo, errors.GetAPIError(errors.ErrMalformedXML))
h.logAndSendError(w, "invalid versioning configuration", reqInfo, errors.GetAPIError(errors.ErrMalformedXML))
return
}
@ -46,34 +45,33 @@ func (h *handler) PutBucketVersioningHandler(w http.ResponseWriter, r *http.Requ
}
if p.Settings.VersioningSuspended() && bktInfo.ObjectLockEnabled {
h.logAndSendError(ctx, w, "couldn't suspend bucket versioning", reqInfo, errors.GetAPIError(errors.ErrObjectLockConfigurationVersioningCannotBeChanged))
h.logAndSendError(w, "couldn't suspend bucket versioning", reqInfo, errors.GetAPIError(errors.ErrObjectLockConfigurationVersioningCannotBeChanged))
return
}
if err = h.obj.PutBucketSettings(ctx, p); err != nil {
h.logAndSendError(ctx, w, "couldn't put update versioning settings", reqInfo, err)
if err = h.obj.PutBucketSettings(r.Context(), p); err != nil {
h.logAndSendError(w, "couldn't put update versioning settings", reqInfo, err)
}
}
// GetBucketVersioningHandler implements bucket versioning getter handler.
func (h *handler) GetBucketVersioningHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
reqInfo := middleware.GetReqInfo(ctx)
reqInfo := middleware.GetReqInfo(r.Context())
bktInfo, err := h.getBucketAndCheckOwner(r, reqInfo.BucketName)
if err != nil {
h.logAndSendError(ctx, w, "could not get bucket info", reqInfo, err)
h.logAndSendError(w, "could not get bucket info", reqInfo, err)
return
}
settings, err := h.obj.GetBucketSettings(ctx, bktInfo)
settings, err := h.obj.GetBucketSettings(r.Context(), bktInfo)
if err != nil {
h.logAndSendError(ctx, w, "couldn't get version settings", reqInfo, err)
h.logAndSendError(w, "couldn't get version settings", reqInfo, err)
return
}
if err = middleware.EncodeToResponse(w, formVersioningConfiguration(settings)); err != nil {
h.logAndSendError(ctx, w, "something went wrong", reqInfo, err)
h.logAndSendError(w, "something went wrong", reqInfo, err)
}
}

View file

@ -95,6 +95,7 @@ const (
DefaultLocationConstraint = "default"
StreamingContentSHA256 = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
StreamingContentV4aSHA256 = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD"
DefaultStorageClass = "STANDARD"
)
@ -125,7 +126,9 @@ var SystemMetadata = map[string]struct{}{
ContentLanguage: {},
}
func IsSignedStreamingV4(r *http.Request) bool {
return r.Header.Get(AmzContentSha256) == StreamingContentSHA256 &&
func IsSignedStreamingV4(r *http.Request) (string, bool) {
shaHeader := r.Header.Get(AmzContentSha256)
return shaHeader,
(shaHeader == StreamingContentSHA256 || shaHeader == StreamingContentV4aSHA256) &&
r.Method == http.MethodPut
}

View file

@ -5,7 +5,6 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
"go.uber.org/zap"
@ -20,7 +19,6 @@ type Cache struct {
bucketCache *cache.BucketCache
systemCache *cache.SystemCache
accessCache *cache.AccessControlCache
networkInfoCache *cache.NetworkInfoCache
}
// CachesConfig contains params for caches.
@ -33,7 +31,6 @@ type CachesConfig struct {
Buckets *cache.Config
System *cache.Config
AccessControl *cache.Config
NetworkInfo *cache.NetworkInfoCacheConfig
}
// DefaultCachesConfigs returns filled configs.
@ -47,7 +44,6 @@ func DefaultCachesConfigs(logger *zap.Logger) *CachesConfig {
Buckets: cache.DefaultBucketConfig(logger),
System: cache.DefaultSystemConfig(logger),
AccessControl: cache.DefaultAccessControlConfig(logger),
NetworkInfo: cache.DefaultNetworkInfoConfig(logger),
}
}
@ -61,7 +57,6 @@ func NewCache(cfg *CachesConfig) *Cache {
bucketCache: cache.NewBucketCache(cfg.Buckets),
systemCache: cache.NewSystemCache(cfg.System),
accessCache: cache.NewAccessControlCache(cfg.AccessControl),
networkInfoCache: cache.NewNetworkInfoCache(cfg.NetworkInfo),
}
}
@ -288,13 +283,3 @@ func (c *Cache) PutLifecycleConfiguration(owner user.ID, bkt *data.BucketInfo, c
func (c *Cache) DeleteLifecycleConfiguration(bktInfo *data.BucketInfo) {
c.systemCache.Delete(bktInfo.LifecycleConfigurationObjectName())
}
func (c *Cache) GetNetworkInfo() *netmap.NetworkInfo {
return c.networkInfoCache.Get()
}
func (c *Cache) PutNetworkInfo(info netmap.NetworkInfo) {
if err := c.networkInfoCache.Put(info); err != nil {
c.logger.Warn(logs.CouldntCacheNetworkInfo, zap.Error(err))
}
}

View file

@ -6,8 +6,7 @@ import (
"fmt"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/tree"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
)
func (n *Layer) GetObjectTaggingAndLock(ctx context.Context, objVersion *data.ObjectVersion, nodeVersion *data.NodeVersion) (map[string]string, data.LockInfo, error) {
@ -30,8 +29,8 @@ func (n *Layer) GetObjectTaggingAndLock(ctx context.Context, objVersion *data.Ob
tags, lockInfo, err = n.treeService.GetObjectTaggingAndLock(ctx, objVersion.BktInfo, nodeVersion)
if err != nil {
if errors.Is(err, tree.ErrNodeNotFound) {
return nil, data.LockInfo{}, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrNoSuchKey), err.Error())
if errors.Is(err, ErrNodeNotFound) {
return nil, data.LockInfo{}, fmt.Errorf("%w: %s", s3errors.GetAPIError(s3errors.ErrNoSuchKey), err.Error())
}
return nil, data.LockInfo{}, err
}

View file

@ -7,8 +7,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
apierr "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer/frostfs"
s3errors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/logs"
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client"
@ -21,7 +20,7 @@ const (
AttributeLockEnabled = "LockEnabled"
)
func (n *Layer) containerInfo(ctx context.Context, prm frostfs.PrmContainer) (*data.BucketInfo, error) {
func (n *Layer) containerInfo(ctx context.Context, prm PrmContainer) (*data.BucketInfo, error) {
var (
err error
res *container.Container
@ -38,7 +37,7 @@ func (n *Layer) containerInfo(ctx context.Context, prm frostfs.PrmContainer) (*d
res, err = n.frostFS.Container(ctx, prm)
if err != nil {
if client.IsErrContainerNotFound(err) {
return nil, fmt.Errorf("%w: %s", apierr.GetAPIError(apierr.ErrNoSuchBucket), err.Error())
return nil, fmt.Errorf("%w: %s", s3errors.GetAPIError(s3errors.ErrNoSuchBucket), err.Error())
}
return nil, fmt.Errorf("get frostfs container: %w", err)
}
@ -78,7 +77,7 @@ func (n *Layer) containerInfo(ctx context.Context, prm frostfs.PrmContainer) (*d
func (n *Layer) containerList(ctx context.Context) ([]*data.BucketInfo, error) {
stoken := n.SessionTokenForRead(ctx)
prm := frostfs.PrmUserContainers{
prm := PrmUserContainers{
UserID: n.BearerOwner(ctx),
SessionToken: stoken,
}
@ -91,7 +90,7 @@ func (n *Layer) containerList(ctx context.Context) ([]*data.BucketInfo, error) {
list := make([]*data.BucketInfo, 0, len(res))
for i := range res {
getPrm := frostfs.PrmContainer{
getPrm := PrmContainer{
ContainerID: res[i],
SessionToken: stoken,
}
@ -133,7 +132,7 @@ func (n *Layer) createContainer(ctx context.Context, p *CreateBucketParams) (*da
})
}
res, err := n.frostFS.CreateContainer(ctx, frostfs.PrmContainerCreate{
res, err := n.frostFS.CreateContainer(ctx, PrmContainerCreate{
Creator: bktInfo.Owner,
Policy: p.Policy,
Name: p.Name,

Some files were not shown because too many files have changed in this diff Show more