From 81bfd2c1cb0864591fdd5e8716a2c2628b5a8652 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 24 Apr 2019 19:52:58 -0700 Subject: [PATCH] Add tests for AWS provisioner Fixes #68 --- authority/provisioner/aws.go | 21 +- authority/provisioner/aws_test.go | 477 +++++++++++++++++++--------- authority/provisioner/utils_test.go | 199 ++++++++++++ 3 files changed, 542 insertions(+), 155 deletions(-) diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 612a2498..4a017069 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -347,16 +347,6 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { return nil, errors.Wrap(err, "error verifying claims") } - // According to "rfc7519 JSON Web Token" acceptable skew should be no - // more than a few minutes. - if err = payload.ValidateWithLeeway(jose.Expected{ - Issuer: awsIssuer, - Audience: []string{p.GetID()}, - Time: time.Now().UTC(), - }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") - } - // Validate identity document signature if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil { return nil, err @@ -378,6 +368,17 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { return nil, errors.New("identity document region cannot be empty") } + // According to "rfc7519 JSON Web Token" acceptable skew should be no + // more than a few minutes. + if err = payload.ValidateWithLeeway(jose.Expected{ + Issuer: awsIssuer, + Subject: doc.InstanceID, + Audience: []string{p.GetID()}, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + return nil, errors.Wrapf(err, "invalid token") + } + // validate accounts if len(p.Accounts) > 0 { var found bool diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 1b998b7b..429e583e 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -1,181 +1,368 @@ -// +build ignore - package provisioner import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" "crypto/x509" - "encoding/base64" + "encoding/hex" "encoding/pem" - "net/http" - "net/http/httptest" + "fmt" + "strings" "testing" + "time" - "github.com/fullsailor/pkcs7" "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" ) -var rsaCert = `-----BEGIN CERTIFICATE----- -MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV -BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw -FgYDVQQKEw9BbWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3Mu -Y29tMB4XDTE0MDYwNTE0MjgwMloXDTI0MDYwNTE0MjgwMlowajELMAkGA1UEBhMC -VVMxEzARBgNVBAgTCldhc2hpbmd0b24xEDAOBgNVBAcTB1NlYXR0bGUxGDAWBgNV -BAoTD0FtYXpvbi5jb20gSW5jLjEaMBgGA1UEAxMRZWMyLmFtYXpvbmF3cy5jb20w -gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIe9GN//SRK2knbjySG0ho3yqQM3 -e2TDhWO8D2e8+XZqck754gFSo99AbT2RmXClambI7xsYHZFapbELC4H91ycihvrD -jbST1ZjkLQgga0NE1q43eS68ZeTDccScXQSNivSlzJZS8HJZjgqzBlXjZftjtdJL -XeE4hwvo0sD4f3j9AgMBAAGjgc8wgcwwHQYDVR0OBBYEFCXWzAgVyrbwnFncFFIs -77VBdlE4MIGcBgNVHSMEgZQwgZGAFCXWzAgVyrbwnFncFFIs77VBdlE4oW6kbDBq -MQswCQYDVQQGEwJVUzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2Vh -dHRsZTEYMBYGA1UEChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1h -em9uYXdzLmNvbYIJAKnL4UEDMN/FMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF -BQADgYEAFYcz1OgEhQBXIwIdsgCOS8vEtiJYF+j9uO6jz7VOmJqO+pRlAbRlvY8T -C1haGgSI/A1uZUKs/Zfnph0oEI0/hu1IIJ/SKBDtN5lvmZ/IzbOPIJWirlsllQIQ -7zvWbGd9c9+Rm3p04oTvhup99la7kZqevJK0QRdD/6NpCKsqP/0= ------END CERTIFICATE-----` - -var rsaSig = `eYko51V+DBTE/pLMwqH9tekcIGdIL6jGkgmh0faKQbHUrWVfaw2ffx032iqbEkvbqIMx0I4ewl+Cq5IejPQ5ax4+Nb9gSoMHS8VCjAUkpj9dUXPG2DEvTHukpvUTy8fGn1a/3LS5GdEPnDVkMj2QDHDBGskH4eA46x9c069xeyE=` - -var dsaCert = `-----BEGIN CERTIFICATE----- -MIIC7TCCAq0CCQCWukjZ5V4aZzAJBgcqhkjOOAQDMFwxCzAJBgNVBAYTAlVTMRkw -FwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYD -VQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAeFw0xMjAxMDUxMjU2MTJaFw0z -ODAxMDUxMjU2MTJaMFwxCzAJBgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9u -IFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNl -cnZpY2VzIExMQzCCAbcwggEsBgcqhkjOOAQBMIIBHwKBgQCjkvcS2bb1VQ4yt/5e -ih5OO6kK/n1Lzllr7D8ZwtQP8fOEpp5E2ng+D6Ud1Z1gYipr58Kj3nssSNpI6bX3 -VyIQzK7wLclnd/YozqNNmgIyZecN7EglK9ITHJLP+x8FtUpt3QbyYXJdmVMegN6P -hviYt5JH/nYl4hh3Pa1HJdskgQIVALVJ3ER11+Ko4tP6nwvHwh6+ERYRAoGBAI1j -k+tkqMVHuAFcvAGKocTgsjJem6/5qomzJuKDmbJNu9Qxw3rAotXau8Qe+MBcJl/U -hhy1KHVpCGl9fueQ2s6IL0CaO/buycU1CiYQk40KNHCcHfNiZbdlx1E9rpUp7bnF -lRa2v1ntMX3caRVDdbtPEWmdxSCYsYFDk4mZrOLBA4GEAAKBgEbmeve5f8LIE/Gf -MNmP9CM5eovQOGx5ho8WqD+aTebs+k2tn92BBPqeZqpWRa5P/+jrdKml1qx4llHW -MXrs3IgIb6+hUIB+S8dz8/mmO0bpr76RoZVCXYab2CZedFut7qc3WUH9+EUAH5mw -vSeDCOUMYQR7R9LINYwouHIziqQYMAkGByqGSM44BAMDLwAwLAIUWXBlk40xTwSw -7HX32MxXYruse9ACFBNGmdX2ZBrVNGrN9N2f6ROk0k9K ------END CERTIFICATE-----` - -var dsaSig = `MIAGCSqGSIb3DQEHAqCAMIACAQExCzAJBgUrDgMCGgUAMIAGCSqGSIb3DQEHAaCAJIAEggHTewog -ICJwcml2YXRlSXAiIDogIjE3Mi4zMS4yMy40NyIsCiAgImRldnBheVByb2R1Y3RDb2RlcyIgOiBu -dWxsLAogICJtYXJrZXRwbGFjZVByb2R1Y3RDb2RlcyIgOiBudWxsLAogICJ2ZXJzaW9uIiA6ICIy -MDE3LTA5LTMwIiwKICAiaW5zdGFuY2VJZCIgOiAiaS0wMmUzYmVjMWY2MDBmNWUzMyIsCiAgImJp -bGxpbmdQcm9kdWN0cyIgOiBudWxsLAogICJpbnN0YW5jZVR5cGUiIDogInQyLm1pY3JvIiwKICAi -YXZhaWxhYmlsaXR5Wm9uZSIgOiAidXMtd2VzdC0xYiIsCiAgImtlcm5lbElkIiA6IG51bGwsCiAg -InJhbWRpc2tJZCIgOiBudWxsLAogICJhY2NvdW50SWQiIDogIjgwNzQ5MjQ3MzI2MyIsCiAgImFy -Y2hpdGVjdHVyZSIgOiAieDg2XzY0IiwKICAiaW1hZ2VJZCIgOiAiYW1pLTFjMWQyMTdjIiwKICAi -cGVuZGluZ1RpbWUiIDogIjIwMTctMTEtMjFUMDA6MjU6MjNaIiwKICAicmVnaW9uIiA6ICJ1cy13 -ZXN0LTEiCn0AAAAAAAAxggEYMIIBFAIBATBpMFwxCzAJBgNVBAYTAlVTMRkwFwYDVQQIExBXYXNo -aW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZp -Y2VzIExMQwIJAJa6SNnlXhpnMAkGBSsOAwIaBQCgXTAYBgkqhkiG9w0BCQMxCwYJKoZIhvcNAQcB -MBwGCSqGSIb3DQEJBTEPFw0xODA3MzAyMzMxMDRaMCMGCSqGSIb3DQEJBDEWBBQUze548OLd+uOT -aOSTDLlV9mevbTAJBgcqhkjOOAQDBC8wLQIUDGeP44Ge1atMQghe+ENV4IDM0zQCFQCBTOEvfKu+ -uscwutj+7RCNgSVaWgAAAAAAAA==` - -var doc = `{ - "privateIp" : "172.31.23.47", - "devpayProductCodes" : null, - "marketplaceProductCodes" : null, - "version" : "2017-09-30", - "instanceId" : "i-02e3bec1f600f5e33", - "billingProducts" : null, - "instanceType" : "t2.micro", - "availabilityZone" : "us-west-1b", - "kernelId" : null, - "ramdiskId" : null, - "accountId" : "807492473263", - "architecture" : "x86_64", - "imageId" : "ami-1c1d217c", - "pendingTime" : "2017-11-21T00:25:23Z", - "region" : "us-west-1" -}` - -func TestAWSRSA(t *testing.T) { - block, _ := pem.Decode([]byte(rsaCert)) - - cert, err := x509.ParseCertificate(block.Bytes) - assert.FatalError(t, err) - - signature, err := base64.StdEncoding.DecodeString(rsaSig) - assert.FatalError(t, err) - - err = cert.CheckSignature(x509.SHA256WithRSA, []byte(doc), signature) +func TestAWS_Getters(t *testing.T) { + p, err := generateAWS() assert.FatalError(t, err) + aud := "aws:" + p.Name + if got := p.GetID(); got != aud { + t.Errorf("AWS.GetID() = %v, want %v", got, aud) + } + if got := p.GetName(); got != p.Name { + t.Errorf("AWS.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeAWS { + t.Errorf("AWS.GetType() = %v, want %v", got, TypeAWS) + } + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("AWS.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, "", "", false) + } } -func TestAWSDSA(t *testing.T) { - block, _ := pem.Decode([]byte(dsaCert)) - - cert, err := x509.ParseCertificate(block.Bytes) +func TestAWS_GetTokenID(t *testing.T) { + p1, srv, err := generateAWSWithServer() assert.FatalError(t, err) - - signature, err := base64.StdEncoding.DecodeString(dsaSig) - assert.FatalError(t, err) - - p7, err := pkcs7.Parse(signature) - assert.FatalError(t, err) - - p7.Certificates = append(p7.Certificates, cert) - - assert.FatalError(t, p7.Verify()) -} - -func TestAWS_GetIdentityToken(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/document": - w.Write([]byte(doc)) - case "/signature": - w.Write([]byte(rsaSig)) - default: - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - } - })) defer srv.Close() - config, err := newAWSConfig() + p2, err := generateAWS() assert.FatalError(t, err) - config.identityURL = srv.URL + "/document" - config.signatureURL = srv.URL + "/signature" + p2.Accounts = p1.Accounts + p2.config = p1.config + p2.DisableTrustOnFirstUse = true - type fields struct { - Type string - Name string - Claims *Claims - claimer *Claimer - config *awsConfig + t1, err := p1.GetIdentityToken() + assert.FatalError(t, err) + _, claims, err := parseAWSToken(t1) + assert.FatalError(t, err) + sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID))) + w1 := strings.ToLower(hex.EncodeToString(sum[:])) + + t2, err := p2.GetIdentityToken() + assert.FatalError(t, err) + sum = sha256.Sum256([]byte(t2)) + w2 := strings.ToLower(hex.EncodeToString(sum[:])) + + type args struct { + token string } tests := []struct { name string - fields fields + aws *AWS + args args want string wantErr bool }{ - {"ok", fields{"AWS", "name", nil, nil, config}, "", false}, + {"ok", p1, args{t1}, w1, false}, + {"ok no TOFU", p2, args{t2}, w2, false}, + {"fail", p1, args{"bad-token"}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &AWS{ - Type: tt.fields.Type, - Name: tt.fields.Name, - Claims: tt.fields.Claims, - claimer: tt.fields.claimer, - config: tt.fields.config, + got, err := tt.aws.GetTokenID(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("AWS.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) + return } - got, err := p.GetIdentityToken() + if got != tt.want { + t.Errorf("AWS.GetTokenID() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAWS_GetIdentityToken(t *testing.T) { + p1, srv, err := generateAWSWithServer() + assert.FatalError(t, err) + defer srv.Close() + + p2, err := generateAWS() + assert.FatalError(t, err) + p2.Accounts = p1.Accounts + p2.config.identityURL = srv.URL + "/bad-document" + p2.config.signatureURL = p1.config.signatureURL + + p3, err := generateAWS() + assert.FatalError(t, err) + p3.Accounts = p1.Accounts + p3.config.signatureURL = srv.URL + p3.config.identityURL = p1.config.identityURL + + p4, err := generateAWS() + assert.FatalError(t, err) + p4.Accounts = p1.Accounts + p4.config.signatureURL = srv.URL + "/bad-signature" + p4.config.identityURL = p1.config.identityURL + + tests := []struct { + name string + aws *AWS + wantErr bool + }{ + {"ok", p1, false}, + {"fail identityURL", p2, true}, + {"fail signatureURL", p3, true}, + {"fail signature", p4, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.aws.GetIdentityToken() if (err != nil) != tt.wantErr { t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) return } - if got != tt.want { - t.Errorf("AWS.GetIdentityToken() = %v, want %v", got, tt.want) + if tt.wantErr == false { + _, c, err := parseAWSToken(got) + if assert.NoError(t, err) { + assert.Equals(t, awsIssuer, c.Issuer) + assert.Equals(t, c.document.InstanceID, c.Subject) + assert.Equals(t, jose.Audience{tt.aws.GetID()}, c.Audience) + assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID) + err = tt.aws.config.certificate.CheckSignature( + tt.aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature) + assert.NoError(t, err) + } + } + }) + } +} + +func TestAWS_Init(t *testing.T) { + config := Config{ + Claims: globalProvisionerClaims, + } + badClaims := &Claims{ + DefaultTLSDur: &Duration{0}, + } + + type fields struct { + Type string + Name string + Accounts []string + DisableCustomSANs bool + DisableTrustOnFirstUse bool + Claims *Claims + } + type args struct { + config Config + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{"AWS", "name", []string{"account"}, false, false, nil}, args{config}, false}, + {"fail type ", fields{"", "name", []string{"account"}, false, false, nil}, args{config}, true}, + {"fail name", fields{"AWS", "", []string{"account"}, false, false, nil}, args{config}, true}, + {"fail claims", fields{"AWS", "name", []string{"account"}, false, false, badClaims}, args{config}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &AWS{ + Type: tt.fields.Type, + Name: tt.fields.Name, + Accounts: tt.fields.Accounts, + DisableCustomSANs: tt.fields.DisableCustomSANs, + DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, + Claims: tt.fields.Claims, + } + if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { + t.Errorf("AWS.Init() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAWS_AuthorizeSign(t *testing.T) { + p1, srv, err := generateAWSWithServer() + assert.FatalError(t, err) + defer srv.Close() + + p2, err := generateAWS() + assert.FatalError(t, err) + p2.Accounts = p1.Accounts + p2.config = p1.config + p2.DisableCustomSANs = true + + p3, err := generateAWS() + assert.FatalError(t, err) + p3.config = p1.config + + t1, err := p1.GetIdentityToken() + assert.FatalError(t, err) + t2, err := p2.GetIdentityToken() + assert.FatalError(t, err) + t3, err := p3.GetIdentityToken() + assert.FatalError(t, err) + + block, _ := pem.Decode([]byte(awsTestKey)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + t.Fatal("error decoding AWS key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.FatalError(t, err) + + badKey, err := rsa.GenerateKey(rand.Reader, 1024) + assert.FatalError(t, err) + + t4, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failSubject, err := generateAWSToken( + "bad-subject", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failIssuer, err := generateAWSToken( + "instance-id", "bad-issuer", p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failAudience, err := generateAWSToken( + "instance-id", awsIssuer, "bad-audience", p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failAccount, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), "", "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failInstanceID, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failPrivateIP, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + failRegion, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "", time.Now(), key) + assert.FatalError(t, err) + failExp, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now().Add(-360*time.Second), key) + assert.FatalError(t, err) + failNbf, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now().Add(360*time.Second), key) + assert.FatalError(t, err) + failKey, err := generateAWSToken( + "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), badKey) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + aws *AWS + args args + wantLen int + wantErr bool + }{ + {"ok", p1, args{t1}, 4, false}, + {"ok", p2, args{t2}, 6, false}, + {"ok", p1, args{t4}, 4, false}, + {"fail account", p3, args{t3}, 0, true}, + {"fail token", p1, args{"token"}, 0, true}, + {"fail subject", p1, args{failSubject}, 0, true}, + {"fail issuer", p1, args{failIssuer}, 0, true}, + {"fail audience", p1, args{failAudience}, 0, true}, + {"fail account", p1, args{failAccount}, 0, true}, + {"fail instanceID", p1, args{failInstanceID}, 0, true}, + {"fail privateIP", p1, args{failPrivateIP}, 0, true}, + {"fail region", p1, args{failRegion}, 0, true}, + {"fail exp", p1, args{failExp}, 0, true}, + {"fail nbf", p1, args{failNbf}, 0, true}, + {"fail key", p1, args{failKey}, 0, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.aws.AuthorizeSign(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Len(t, tt.wantLen, got) + }) + } +} + +func TestAWS_AuthorizeRenewal(t *testing.T) { + p1, err := generateAWS() + assert.FatalError(t, err) + p2, err := generateAWS() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{DisableRenewal: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + aws *AWS + args args + wantErr bool + }{ + {"ok", p1, args{nil}, false}, + {"fail", p2, args{nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.aws.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("AWS.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAWS_AuthorizeRevoke(t *testing.T) { + p1, srv, err := generateAWSWithServer() + assert.FatalError(t, err) + defer srv.Close() + + t1, err := p1.GetIdentityToken() + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + aws *AWS + args args + wantErr bool + }{ + {"ok", p1, args{t1}, true}, // revoke is disabled + {"fail", p1, args{"token"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.aws.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) } - t.Error(got) - // parts := strings.Split(got, ".") - // signed, err := base64.RawURLEncoding.DecodeString(parts[0]) - // assert.FatalError(t, err) - // signature, err := base64.RawURLEncoding.DecodeString(parts[1]) - // assert.FatalError(t, err) - // assert.FatalError(t, err, config.certificate.CheckSignature(config.signatureAlgorithm, signed, signature)) }) } } diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 5c078479..47d42622 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -2,12 +2,18 @@ package provisioner import ( "crypto" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/base64" "encoding/hex" "encoding/json" + "encoding/pem" "net/http" "net/http/httptest" "time" + "github.com/pkg/errors" "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" ) @@ -17,6 +23,37 @@ var testAudiences = Audiences{ Revoke: []string{"https://ca.smallstep.com/revoke", "https://ca.smallstep.com/1.0/revoke"}, } +const awsTestCertificate = `-----BEGIN CERTIFICATE----- +MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw +GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0 +MjEyMjU3MzlaMBgxFjAUBgNVBAMTDUFXUyBUZXN0IENlcnQwgZ8wDQYJKoZIhvcN +AQEBBQADgY0AMIGJAoGBAOHMmMXwbXN90SoRl/xXAcJs5TacaVYJ5iNAVWM5KYyF ++JwqYuJp/umLztFUi0oX0luu3EzD4KurVeUJSzZjTFTX1d/NX6hA45+bvdSUOcgV +UghO+2uhBZ4SNFxFRZ7SKvoWIN195l5bVX6/60Eo6+kUCKCkyxW4V/ksWzdXjHnf +AgMBAAGjXzBdMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0G +A1UdDgQWBBRHfLOjEddK/CWCIHNg8Oc/oJa1IzAYBgNVHREEETAPgg1BV1MgVGVz +dCBDZXJ0MA0GCSqGSIb3DQEBCwUAA4GBAKNCiVM9eGb9dW2xNyHaHAmmy7ERB2OJ +7oXHfLjooOavk9lU/Gs2jfX/JSBa84+DzWg9ShmCNLti8CxU/dhzXW7jE/5CcdTa +DCA6B3Yl5TmfG9+D9dtFqRB2CiMgNcsJJE5Dc6pDwBIiSj/MkE0AaGVQmSwn6Cb6 +vX1TAxqeWJHq +-----END CERTIFICATE-----` + +const awsTestKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDhzJjF8G1zfdEqEZf8VwHCbOU2nGlWCeYjQFVjOSmMhficKmLi +af7pi87RVItKF9JbrtxMw+Crq1XlCUs2Y0xU19XfzV+oQOOfm73UlDnIFVIITvtr +oQWeEjRcRUWe0ir6FiDdfeZeW1V+v+tBKOvpFAigpMsVuFf5LFs3V4x53wIDAQAB +AoGADZQFF9oWatyFCHeYYSdGRs/PlNIhD3h262XB/L6CPh4MTi/KVH01RAwROstP +uPvnvXWtb7xTtV8PQj+l0zZzb4W/DLCSBdoRwpuNXyffUCtbI22jPupTsVu+ENWR +3x7HHzoZYjU45ADSTMxEtwD7/zyNgpRKjIA2HYpkt+fI27ECQQD5/AOr9/yQD73x +cquF+FWahWgDL25YeMwdfe1HfpUxUxd9kJJKieB8E2BtBAv9XNguxIBpf7VlAKsF +NFhdfWFHAkEA5zuX8vqDecSzyNNEQd3tugxt1pGOXNesHzuPbdlw3ppN9Rbd93an +uU2TaAvTjr/3EkxulYNRmHs+RSVK54+uqQJAKWurhBQMAibJlzcj2ofiTz8pk9WJ +GBmz4HMcHMuJlumoq8KHqtgbnRNs18Ni5TE8FMu0Z0ak3L52l98rgRokQwJBAJS8 +9KTLF79AFBVeME3eH4jJbe3TeyulX4ZHnZ8fe0b1IqhAqU8A+CpuCB+pW9A7Ewam +O4vZCKd4vzljH6eL+OECQHHxhYoTW7lFpKGnUDG9fPZ3eYzWpgka6w1vvBk10BAu +6fbwppM9pQ7DPMg7V6YGEjjT0gX9B9TttfHxGhvtZNQ= +-----END RSA PRIVATE KEY-----` + func must(args ...interface{}) []interface{} { if l := len(args); l > 0 && args[l-1] != nil { if err, ok := args[l-1].(error); ok { @@ -194,6 +231,103 @@ func generateGCP() (*GCP, error) { }, nil } +func generateAWS() (*AWS, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + accountID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + claimer, err := NewClaimer(nil, globalProvisionerClaims) + if err != nil { + return nil, err + } + block, _ := pem.Decode([]byte(awsTestCertificate)) + if block == nil || block.Type != "CERTIFICATE" { + return nil, errors.New("error decoding AWS certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrap(err, "error parsing AWS certificate") + } + return &AWS{ + Type: "AWS", + Name: name, + Accounts: []string{accountID}, + Claims: &globalProvisionerClaims, + claimer: claimer, + config: &awsConfig{ + identityURL: awsIdentityURL, + signatureURL: awsSignatureURL, + certificate: cert, + signatureAlgorithm: awsSignatureAlgorithm, + }, + }, nil +} + +func generateAWSWithServer() (*AWS, *httptest.Server, error) { + aws, err := generateAWS() + if err != nil { + return nil, nil, err + } + block, _ := pem.Decode([]byte(awsTestKey)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, nil, errors.New("error decoding AWS key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing AWS private key") + } + instanceID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, nil, err + } + imageID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, nil, err + } + doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ + AccountID: aws.Accounts[0], + Architecture: "x86_64", + AvailabilityZone: "us-west-2b", + ImageID: imageID, + InstanceID: instanceID, + InstanceType: "t2.micro", + PendingTime: time.Now(), + PrivateIP: "127.0.0.1", + Region: "us-west-1", + Version: "2017-09-30", + }, "", " ") + if err != nil { + return nil, nil, err + } + + sum := sha256.Sum256(doc) + signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) + if err != nil { + return nil, nil, errors.Wrap(err, "error signing document") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/dynamic/instance-identity/document": + w.Write(doc) + case "/latest/dynamic/instance-identity/signature": + w.Write([]byte(base64.StdEncoding.EncodeToString(signature))) + case "/bad-document": + w.Write([]byte("{}")) + case "/bad-signature": + w.Write([]byte("YmFkLXNpZ25hdHVyZQo=")) + default: + http.NotFound(w, r) + } + })) + aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document" + aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature" + return aws, srv, nil +} + func generateCollection(nJWK, nOIDC int) (*Collection, error) { col := NewCollection(testAudiences) for i := 0; i < nJWK; i++ { @@ -286,6 +420,54 @@ func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone s return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateAWSToken(sub, iss, aud, accountID, instanceID, privateIP, region string, iat time.Time, key crypto.Signer) (string, error) { + doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ + AccountID: accountID, + Architecture: "x86_64", + AvailabilityZone: "us-west-2b", + ImageID: "ami-123123", + InstanceID: instanceID, + InstanceType: "t2.micro", + PendingTime: time.Now(), + PrivateIP: privateIP, + Region: region, + Version: "2017-09-30", + }, "", " ") + if err != nil { + return "", err + } + + sum := sha256.Sum256(doc) + signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) + if err != nil { + return "", errors.Wrap(err, "error signing document") + } + + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.HS256, Key: signature}, + new(jose.SignerOptions).WithType("JWT"), + ) + if err != nil { + return "", err + } + + claims := awsPayload{ + Claims: jose.Claims{ + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + Amazon: awsAmazonPayload{ + Document: doc, + Signature: signature, + }, + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { @@ -298,6 +480,23 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { return tok, claims, nil } +func parseAWSToken(token string) (*jose.JSONWebToken, *awsPayload, error) { + tok, err := jose.ParseSigned(token) + if err != nil { + return nil, nil, err + } + claims := new(awsPayload) + if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil { + return nil, nil, err + } + var doc awsInstanceIdentityDocument + if err := json.Unmarshal(claims.Amazon.Document, &doc); err != nil { + return nil, nil, errors.Wrap(err, "error unmarshaling identity document") + } + claims.document = doc + return tok, claims, nil +} + func generateJWKServer(n int) *httptest.Server { hits := struct { Hits int `json:"hits"`