diff --git a/api/auth/center_test.go b/api/auth/center_test.go index acca1ae..e2feb73 100644 --- a/api/auth/center_test.go +++ b/api/auth/center_test.go @@ -1,12 +1,31 @@ package auth import ( + "bytes" + "context" + "fmt" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" "strings" "testing" "time" + v4 "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/auth/signer/v4" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/cache" "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/errors" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/accessbox" + "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/creds/tokens" + frostfsErrors "git.frostfs.info/TrueCloudLab/frostfs-s3-gw/internal/frostfs/errors" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" + oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id" + oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" ) func TestAuthHeaderParse(t *testing.T) { @@ -123,6 +142,11 @@ func TestCheckFormatContentSHA256(t *testing.T) { hash: "ed7002b439e9ac845f22357d822bac1444730fbdb6016d3ec9432297b9ec9f7s", error: defaultErr, }, + { + name: "invalid hash format: hash size", + hash: "5aadb45520dcd8726b2822a7a78bb53d794f557199d5d4abdedd2c55a4bd6ca73607605c558de3db80c8e86c3196484566163ed1327e82e8b6757d1932113cb8", + error: defaultErr, + }, { name: "unsigned payload", hash: "UNSIGNED-PAYLOAD", @@ -145,3 +169,466 @@ func TestCheckFormatContentSHA256(t *testing.T) { }) } } + +type frostFSMock struct { + objects map[oid.Address]*object.Object +} + +func newFrostFSMock() *frostFSMock { + return &frostFSMock{ + objects: map[oid.Address]*object.Object{}, + } +} + +func (f *frostFSMock) GetCredsObject(_ context.Context, address oid.Address) (*object.Object, error) { + obj, ok := f.objects[address] + if !ok { + return nil, fmt.Errorf("not found") + } + + return obj, nil +} + +func (f *frostFSMock) CreateObject(context.Context, tokens.PrmObjectCreate) (oid.ID, error) { + return oid.ID{}, fmt.Errorf("the mock method is not implemented") +} + +func TestAuthenticate(t *testing.T) { + key, err := keys.NewPrivateKey() + require.NoError(t, err) + + cfg := &cache.Config{ + Size: 10, + Lifetime: 24 * time.Hour, + Logger: zaptest.NewLogger(t), + } + + gateData := []*accessbox.GateData{{ + BearerToken: &bearer.Token{}, + GateKey: key.PublicKey(), + }} + + accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret")) + require.NoError(t, err) + data, err := accessBox.Marshal() + require.NoError(t, err) + + var obj object.Object + obj.SetPayload(data) + addr := oidtest.Address() + obj.SetContainerID(addr.Container()) + obj.SetID(addr.Object()) + + frostfs := newFrostFSMock() + frostfs.objects[addr] = &obj + + accessKeyID := addr.Container().String() + "0" + addr.Object().String() + + awsCreds := credentials.NewStaticCredentials(accessKeyID, secret.SecretKey, "") + defaultSigner := v4.NewSigner(awsCreds) + + service, region := "s3", "default" + invalidValue := "invalid-value" + + bigConfig := tokens.Config{ + FrostFS: frostfs, + Key: key, + CacheConfig: cfg, + } + + for _, tc := range []struct { + name string + prefixes []string + request *http.Request + err bool + errCode errors.ErrorCode + }{ + { + name: "valid sign", + prefixes: []string{addr.Container().String()}, + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + }, + { + name: "no authorization header", + request: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/", nil) + }(), + err: true, + }, + { + name: "invalid authorization header", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set(AuthorizationHdr, invalidValue) + return r + }(), + err: true, + errCode: errors.ErrAuthorizationHeaderMalformed, + }, + { + name: "invalid access key id format", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + signer := v4.NewSigner(credentials.NewStaticCredentials(addr.Object().String(), secret.SecretKey, "")) + _, err = signer.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrInvalidAccessKeyID, + }, + { + name: "not allowed access key id", + prefixes: []string{addr.Object().String()}, + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrAccessDenied, + }, + { + name: "invalid access key id value", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + signer := v4.NewSigner(credentials.NewStaticCredentials(accessKeyID[:len(accessKeyID)-4], secret.SecretKey, "")) + _, err = signer.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrInvalidAccessKeyID, + }, + { + name: "unknown access key id", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + signer := v4.NewSigner(credentials.NewStaticCredentials(addr.Object().String()+"0"+addr.Container().String(), secret.SecretKey, "")) + _, err = signer.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + err: true, + }, + { + name: "invalid signature", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + signer := v4.NewSigner(credentials.NewStaticCredentials(accessKeyID, "secret", "")) + _, err = signer.Sign(r, nil, service, region, time.Now()) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrSignatureDoesNotMatch, + }, + { + name: "invalid signature - AmzDate", + prefixes: []string{addr.Container().String()}, + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Sign(r, nil, service, region, time.Now()) + r.Header.Set(AmzDate, invalidValue) + require.NoError(t, err) + return r + }(), + err: true, + }, + { + name: "invalid AmzContentSHA256", + prefixes: []string{addr.Container().String()}, + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Sign(r, nil, service, region, time.Now()) + r.Header.Set(AmzContentSHA256, invalidValue) + require.NoError(t, err) + return r + }(), + err: true, + }, + { + name: "valid presign", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now()) + require.NoError(t, err) + return r + }(), + }, + { + name: "presign, bad X-Amz-Credential", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + query := url.Values{ + AmzAlgorithm: []string{"AWS4-HMAC-SHA256"}, + AmzCredential: []string{invalidValue}, + } + r.URL.RawQuery = query.Encode() + return r + }(), + err: true, + }, + { + name: "presign, bad X-Amz-Expires", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now()) + queryParams := r.URL.Query() + queryParams.Set("X-Amz-Expires", invalidValue) + r.URL.RawQuery = queryParams.Encode() + require.NoError(t, err) + return r + }(), + err: true, + }, + { + name: "presign, expired", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now().Add(-time.Minute)) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrExpiredPresignRequest, + }, + { + name: "presign, signature from future", + request: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", nil) + _, err = defaultSigner.Presign(r, nil, service, region, time.Minute, time.Now().Add(time.Minute)) + require.NoError(t, err) + return r + }(), + err: true, + errCode: errors.ErrBadRequest, + }, + } { + t.Run(tc.name, func(t *testing.T) { + creds := tokens.New(bigConfig) + cntr := New(creds, tc.prefixes) + box, err := cntr.Authenticate(tc.request) + + if tc.err { + require.Error(t, err) + if tc.errCode > 0 { + err = frostfsErrors.UnwrapErr(err) + require.Equal(t, errors.GetAPIError(tc.errCode), err) + } + } else { + require.NoError(t, err) + require.Equal(t, accessKeyID, box.AuthHeaders.AccessKeyID) + require.Equal(t, region, box.AuthHeaders.Region) + require.Equal(t, secret.SecretKey, box.AccessBox.Gate.SecretKey) + } + }) + } +} + +func TestHTTPPostAuthenticate(t *testing.T) { + const ( + policyBase64 = "eyAiZXhwaXJhdGlvbiI6ICIyMDA3LTEyLTAxVDEyOjAwOjAwLjAwMFoiLAogICJjb25kaXRpb25zIjogWwogICAgeyJhY2wiOiAicHVibGljLXJlYWQiIH0sCiAgICB7ImJ1Y2tldCI6ICJqb2huc21pdGgiIH0sCiAgICBbInN0YXJ0cy13aXRoIiwgIiRrZXkiLCAidXNlci9lcmljLyJdLAogIF0KfQ==" + invalidValue = "invalid-value" + defaultFieldName = "file" + service = "s3" + region = "default" + ) + + key, err := keys.NewPrivateKey() + require.NoError(t, err) + + cfg := &cache.Config{ + Size: 10, + Lifetime: 24 * time.Hour, + Logger: zaptest.NewLogger(t), + } + + gateData := []*accessbox.GateData{{ + BearerToken: &bearer.Token{}, + GateKey: key.PublicKey(), + }} + + accessBox, secret, err := accessbox.PackTokens(gateData, []byte("secret")) + require.NoError(t, err) + data, err := accessBox.Marshal() + require.NoError(t, err) + + var obj object.Object + obj.SetPayload(data) + addr := oidtest.Address() + obj.SetContainerID(addr.Container()) + obj.SetID(addr.Object()) + + frostfs := newFrostFSMock() + frostfs.objects[addr] = &obj + + accessKeyID := addr.Container().String() + "0" + addr.Object().String() + invalidAccessKeyID := oidtest.Address().String() + "0" + oidtest.Address().Object().String() + + timeToSign := time.Now() + timeToSignStr := timeToSign.Format("20060102T150405Z") + + bigConfig := tokens.Config{ + FrostFS: frostfs, + Key: key, + CacheConfig: cfg, + } + + for _, tc := range []struct { + name string + prefixes []string + request *http.Request + err bool + errCode errors.ErrorCode + }{ + { + name: "HTTP POST valid", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) + }(), + }, + { + name: "HTTP POST valid with custom field name", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, "files") + }(), + }, + { + name: "HTTP POST valid with field name with a capital letter", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, "File") + }(), + }, + { + name: "HTTP POST invalid multipart form", + request: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ContentTypeHdr, "multipart/form-data") + + return req + }(), + err: true, + errCode: errors.ErrInvalidArgument, + }, + { + name: "HTTP POST invalid signature date time", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, invalidValue, sign, defaultFieldName) + }(), + err: true, + }, + { + name: "HTTP POST invalid creds", + request: func() *http.Request { + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, invalidValue, timeToSignStr, sign, defaultFieldName) + }(), + err: true, + errCode: errors.ErrAuthorizationHeaderMalformed, + }, + { + name: "HTTP POST missing policy", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, "", creds, timeToSignStr, sign, defaultFieldName) + }(), + err: true, + }, + { + name: "HTTP POST invalid accessKeyId", + request: func() *http.Request { + creds := getCredsStr(invalidValue, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) + }(), + err: true, + }, + { + name: "HTTP POST invalid accessKeyId - a non-existent box", + request: func() *http.Request { + creds := getCredsStr(invalidAccessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, policyBase64) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) + }(), + err: true, + }, + { + name: "HTTP POST invalid signature", + request: func() *http.Request { + creds := getCredsStr(accessKeyID, timeToSignStr, region, service) + sign := signStr(secret.SecretKey, service, region, timeToSign, invalidValue) + + return getRequestWithMultipartForm(t, policyBase64, creds, timeToSignStr, sign, defaultFieldName) + }(), + err: true, + errCode: errors.ErrSignatureDoesNotMatch, + }, + } { + t.Run(tc.name, func(t *testing.T) { + creds := tokens.New(bigConfig) + cntr := New(creds, tc.prefixes) + box, err := cntr.Authenticate(tc.request) + + if tc.err { + require.Error(t, err) + if tc.errCode > 0 { + err = frostfsErrors.UnwrapErr(err) + require.Equal(t, errors.GetAPIError(tc.errCode), err) + } + } else { + require.NoError(t, err) + require.Equal(t, secret.SecretKey, box.AccessBox.Gate.SecretKey) + } + }) + } +} + +func getCredsStr(accessKeyID, timeToSign, region, service string) string { + return accessKeyID + "/" + timeToSign + "/" + region + "/" + service + "/aws4_request" +} + +func getRequestWithMultipartForm(t *testing.T, policy, creds, date, sign, fieldName string) *http.Request { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + defer writer.Close() + + err := writer.WriteField("Policy", policy) + require.NoError(t, err) + err = writer.WriteField(AmzCredential, creds) + require.NoError(t, err) + err = writer.WriteField(AmzDate, date) + require.NoError(t, err) + err = writer.WriteField(AmzSignature, sign) + require.NoError(t, err) + _, err = writer.CreateFormFile(fieldName, "test.txt") + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", body) + req.Header.Set(ContentTypeHdr, writer.FormDataContentType()) + + return req +}