aws: add tests covering metadata service versions

* Add constructor tests for the aws provisioner.
* Add a test to make sure the "v1" logic continues to work.

By and large, v2 is the way to go. However, there are some instances of
things that specifically request metadata service version 1 and so this
adds minimal coverage to make sure we don't accidentally break the path
should anyone need to depend on the former logic.
This commit is contained in:
David Cowden 2020-07-22 16:52:06 -07:00
parent 5efe5f3573
commit 51f16ee2e0
3 changed files with 134 additions and 6 deletions

View file

@ -386,6 +386,9 @@ func (p *AWS) readURL(url string) ([]byte, error) {
default:
return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v)
}
if resp != nil {
resp.Body.Close()
}
}
// all versions have been exhausted and we haven't returned successfully yet so pass

View file

@ -187,6 +187,31 @@ func TestAWS_GetIdentityToken(t *testing.T) {
}
}
func TestAWS_GetIdentityTokenV1Only(t *testing.T) {
aws, srv, err := generateAWSWithServerV1Only()
assert.FatalError(t, err)
defer srv.Close()
subject := "foo.local"
caURL := "https://ca.smallstep.com"
u, err := url.Parse(caURL)
assert.Nil(t, err)
token, err := aws.GetIdentityToken(subject, caURL)
assert.Nil(t, err)
_, c, err := parseAWSToken(token)
if assert.NoError(t, err) {
assert.Equals(t, awsIssuer, c.Issuer)
assert.Equals(t, subject, c.Subject)
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience)
assert.Equals(t, aws.Accounts[0], c.document.AccountID)
err = aws.config.certificate.CheckSignature(
aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
assert.NoError(t, err)
}
}
func TestAWS_Init(t *testing.T) {
config := Config{
Claims: globalProvisionerClaims,
@ -203,6 +228,7 @@ func TestAWS_Init(t *testing.T) {
DisableCustomSANs bool
DisableTrustOnFirstUse bool
InstanceAge Duration
IMDSVersions []string
Claims *Claims
}
type args struct {
@ -214,12 +240,15 @@ func TestAWS_Init(t *testing.T) {
args args
wantErr bool
}{
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false},
{"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false},
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true},
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true},
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true},
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true},
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false},
{"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false},
{"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false},
{"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false},
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true},
{"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true},
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -230,6 +259,7 @@ func TestAWS_Init(t *testing.T) {
DisableCustomSANs: tt.fields.DisableCustomSANs,
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
InstanceAge: tt.fields.InstanceAge,
IMDSVersions: tt.fields.IMDSVersions,
Claims: tt.fields.Claims,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {

View file

@ -495,6 +495,101 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
return aws, srv, nil
}
func generateAWSV1Only() (*AWS, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
accountID, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
tokenURL: awsAPITokenURL,
tokenTTL: awsAPITokenTTL,
certificate: cert,
signatureAlgorithm: awsSignatureAlgorithm,
},
audiences: testAudiences.WithFragment("aws/" + name),
}, nil
}
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
aws, err := generateAWSV1Only()
if err != nil {
return nil, nil, err
}
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, nil, errors.New("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
}
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
AccountID: aws.Accounts[0],
Architecture: "x86_64",
AvailabilityZone: "us-west-2b",
ImageID: "image-id",
InstanceID: "instance-id",
InstanceType: "t2.micro",
PendingTime: time.Now(),
PrivateIP: "127.0.0.1",
Region: "us-west-1",
Version: "2017-09-30",
}, "", " ")
if err != nil {
return nil, nil, err
}
sum := sha256.Sum256(doc)
signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
if err != nil {
return nil, nil, errors.Wrap(err, "error signing document")
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/dynamic/instance-identity/document":
w.Write(doc)
case "/latest/dynamic/instance-identity/signature":
w.Write([]byte(base64.StdEncoding.EncodeToString(signature)))
case "/bad-document":
w.Write([]byte("{}"))
case "/bad-signature":
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
case "/bad-json":
w.Write([]byte("{"))
default:
http.NotFound(w, r)
}
}))
aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document"
aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature"
return aws, srv, nil
}
func generateAzure() (*Azure, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {