aws: add tests covering metadata service versions

* Add constructor tests for the aws provisioner.
* Add a test to make sure the "v1" logic continues to work.

By and large, v2 is the way to go. However, there are some instances of
things that specifically request metadata service version 1 and so this
adds minimal coverage to make sure we don't accidentally break the path
should anyone need to depend on the former logic.
This commit is contained in:
David Cowden 2020-07-22 16:52:06 -07:00
parent 5efe5f3573
commit 51f16ee2e0
3 changed files with 134 additions and 6 deletions

View file

@ -386,6 +386,9 @@ func (p *AWS) readURL(url string) ([]byte, error) {
default: default:
return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v) return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v)
} }
if resp != nil {
resp.Body.Close()
}
} }
// all versions have been exhausted and we haven't returned successfully yet so pass // all versions have been exhausted and we haven't returned successfully yet so pass

View file

@ -187,6 +187,31 @@ func TestAWS_GetIdentityToken(t *testing.T) {
} }
} }
func TestAWS_GetIdentityTokenV1Only(t *testing.T) {
aws, srv, err := generateAWSWithServerV1Only()
assert.FatalError(t, err)
defer srv.Close()
subject := "foo.local"
caURL := "https://ca.smallstep.com"
u, err := url.Parse(caURL)
assert.Nil(t, err)
token, err := aws.GetIdentityToken(subject, caURL)
assert.Nil(t, err)
_, c, err := parseAWSToken(token)
if assert.NoError(t, err) {
assert.Equals(t, awsIssuer, c.Issuer)
assert.Equals(t, subject, c.Subject)
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience)
assert.Equals(t, aws.Accounts[0], c.document.AccountID)
err = aws.config.certificate.CheckSignature(
aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
assert.NoError(t, err)
}
}
func TestAWS_Init(t *testing.T) { func TestAWS_Init(t *testing.T) {
config := Config{ config := Config{
Claims: globalProvisionerClaims, Claims: globalProvisionerClaims,
@ -203,6 +228,7 @@ func TestAWS_Init(t *testing.T) {
DisableCustomSANs bool DisableCustomSANs bool
DisableTrustOnFirstUse bool DisableTrustOnFirstUse bool
InstanceAge Duration InstanceAge Duration
IMDSVersions []string
Claims *Claims Claims *Claims
} }
type args struct { type args struct {
@ -214,12 +240,15 @@ func TestAWS_Init(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false}, {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false},
{"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false}, {"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false},
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true}, {"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false},
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true}, {"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false},
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true}, {"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true}, {"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true},
{"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true},
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -230,6 +259,7 @@ func TestAWS_Init(t *testing.T) {
DisableCustomSANs: tt.fields.DisableCustomSANs, DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse, DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
InstanceAge: tt.fields.InstanceAge, InstanceAge: tt.fields.InstanceAge,
IMDSVersions: tt.fields.IMDSVersions,
Claims: tt.fields.Claims, Claims: tt.fields.Claims,
} }
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {

View file

@ -495,6 +495,101 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
return aws, srv, nil return aws, srv, nil
} }
func generateAWSV1Only() (*AWS, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
accountID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
tokenURL: awsAPITokenURL,
tokenTTL: awsAPITokenTTL,
certificate: cert,
signatureAlgorithm: awsSignatureAlgorithm,
},
audiences: testAudiences.WithFragment("aws/" + name),
}, nil
}
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
aws, err := generateAWSV1Only()
if err != nil {
return nil, nil, err
}
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, nil, errors.New("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
}
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
AccountID: aws.Accounts[0],
Architecture: "x86_64",
AvailabilityZone: "us-west-2b",
ImageID: "image-id",
InstanceID: "instance-id",
InstanceType: "t2.micro",
PendingTime: time.Now(),
PrivateIP: "127.0.0.1",
Region: "us-west-1",
Version: "2017-09-30",
}, "", " ")
if err != nil {
return nil, nil, err
}
sum := sha256.Sum256(doc)
signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
if err != nil {
return nil, nil, errors.Wrap(err, "error signing document")
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/dynamic/instance-identity/document":
w.Write(doc)
case "/latest/dynamic/instance-identity/signature":
w.Write([]byte(base64.StdEncoding.EncodeToString(signature)))
case "/bad-document":
w.Write([]byte("{}"))
case "/bad-signature":
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
case "/bad-json":
w.Write([]byte("{"))
default:
http.NotFound(w, r)
}
}))
aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document"
aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature"
return aws, srv, nil
}
func generateAzure() (*Azure, error) { func generateAzure() (*Azure, error) {
name, err := randutil.Alphanumeric(10) name, err := randutil.Alphanumeric(10)
if err != nil { if err != nil {