From 51f16ee2e01062aa7d31b80245d8e6663fe4e28c Mon Sep 17 00:00:00 2001 From: David Cowden Date: Wed, 22 Jul 2020 16:52:06 -0700 Subject: [PATCH] aws: add tests covering metadata service versions * Add constructor tests for the aws provisioner. * Add a test to make sure the "v1" logic continues to work. By and large, v2 is the way to go. However, there are some instances of things that specifically request metadata service version 1 and so this adds minimal coverage to make sure we don't accidentally break the path should anyone need to depend on the former logic. --- authority/provisioner/aws.go | 3 + authority/provisioner/aws_test.go | 42 +++++++++++-- authority/provisioner/utils_test.go | 95 +++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 6 deletions(-) diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 3e777e08..56a96490 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -386,6 +386,9 @@ func (p *AWS) readURL(url string) ([]byte, error) { default: return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v) } + if resp != nil { + resp.Body.Close() + } } // all versions have been exhausted and we haven't returned successfully yet so pass diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 50cdd32a..b5728de4 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -187,6 +187,31 @@ func TestAWS_GetIdentityToken(t *testing.T) { } } +func TestAWS_GetIdentityTokenV1Only(t *testing.T) { + aws, srv, err := generateAWSWithServerV1Only() + assert.FatalError(t, err) + defer srv.Close() + + subject := "foo.local" + caURL := "https://ca.smallstep.com" + u, err := url.Parse(caURL) + assert.Nil(t, err) + + token, err := aws.GetIdentityToken(subject, caURL) + assert.Nil(t, err) + + _, c, err := parseAWSToken(token) + if assert.NoError(t, err) { + assert.Equals(t, awsIssuer, c.Issuer) + assert.Equals(t, subject, c.Subject) + assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience) + assert.Equals(t, aws.Accounts[0], c.document.AccountID) + err = aws.config.certificate.CheckSignature( + aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature) + assert.NoError(t, err) + } +} + func TestAWS_Init(t *testing.T) { config := Config{ Claims: globalProvisionerClaims, @@ -203,6 +228,7 @@ func TestAWS_Init(t *testing.T) { DisableCustomSANs bool DisableTrustOnFirstUse bool InstanceAge Duration + IMDSVersions []string Claims *Claims } type args struct { @@ -214,12 +240,15 @@ func TestAWS_Init(t *testing.T) { args args wantErr bool }{ - {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false}, - {"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false}, - {"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true}, - {"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true}, - {"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true}, - {"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true}, + {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false}, + {"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false}, + {"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false}, + {"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false}, + {"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true}, + {"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true}, + {"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true}, + {"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true}, + {"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -230,6 +259,7 @@ func TestAWS_Init(t *testing.T) { DisableCustomSANs: tt.fields.DisableCustomSANs, DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, InstanceAge: tt.fields.InstanceAge, + IMDSVersions: tt.fields.IMDSVersions, Claims: tt.fields.Claims, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 84283631..52fb470a 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -495,6 +495,101 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) { return aws, srv, nil } +func generateAWSV1Only() (*AWS, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + accountID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + claimer, err := NewClaimer(nil, globalProvisionerClaims) + if err != nil { + return nil, err + } + block, _ := pem.Decode([]byte(awsTestCertificate)) + if block == nil || block.Type != "CERTIFICATE" { + return nil, errors.New("error decoding AWS certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrap(err, "error parsing AWS certificate") + } + return &AWS{ + Type: "AWS", + Name: name, + Accounts: []string{accountID}, + Claims: &globalProvisionerClaims, + IMDSVersions: []string{"v1"}, + claimer: claimer, + config: &awsConfig{ + identityURL: awsIdentityURL, + signatureURL: awsSignatureURL, + tokenURL: awsAPITokenURL, + tokenTTL: awsAPITokenTTL, + certificate: cert, + signatureAlgorithm: awsSignatureAlgorithm, + }, + audiences: testAudiences.WithFragment("aws/" + name), + }, nil +} + +func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { + aws, err := generateAWSV1Only() + if err != nil { + return nil, nil, err + } + block, _ := pem.Decode([]byte(awsTestKey)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, nil, errors.New("error decoding AWS key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing AWS private key") + } + doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ + AccountID: aws.Accounts[0], + Architecture: "x86_64", + AvailabilityZone: "us-west-2b", + ImageID: "image-id", + InstanceID: "instance-id", + InstanceType: "t2.micro", + PendingTime: time.Now(), + PrivateIP: "127.0.0.1", + Region: "us-west-1", + Version: "2017-09-30", + }, "", " ") + if err != nil { + return nil, nil, err + } + + sum := sha256.Sum256(doc) + signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) + if err != nil { + return nil, nil, errors.Wrap(err, "error signing document") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/dynamic/instance-identity/document": + w.Write(doc) + case "/latest/dynamic/instance-identity/signature": + w.Write([]byte(base64.StdEncoding.EncodeToString(signature))) + case "/bad-document": + w.Write([]byte("{}")) + case "/bad-signature": + w.Write([]byte("YmFkLXNpZ25hdHVyZQo=")) + case "/bad-json": + w.Write([]byte("{")) + default: + http.NotFound(w, r) + } + })) + aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document" + aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature" + return aws, srv, nil +} + func generateAzure() (*Azure, error) { name, err := randutil.Alphanumeric(10) if err != nil {