aws: add tests covering metadata service versions
* Add constructor tests for the aws provisioner. * Add a test to make sure the "v1" logic continues to work. By and large, v2 is the way to go. However, there are some instances of things that specifically request metadata service version 1 and so this adds minimal coverage to make sure we don't accidentally break the path should anyone need to depend on the former logic.
This commit is contained in:
parent
5efe5f3573
commit
51f16ee2e0
3 changed files with 134 additions and 6 deletions
|
@ -386,6 +386,9 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
|||
default:
|
||||
return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v)
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// all versions have been exhausted and we haven't returned successfully yet so pass
|
||||
|
|
|
@ -187,6 +187,31 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAWS_GetIdentityTokenV1Only(t *testing.T) {
|
||||
aws, srv, err := generateAWSWithServerV1Only()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
subject := "foo.local"
|
||||
caURL := "https://ca.smallstep.com"
|
||||
u, err := url.Parse(caURL)
|
||||
assert.Nil(t, err)
|
||||
|
||||
token, err := aws.GetIdentityToken(subject, caURL)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, c, err := parseAWSToken(token)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equals(t, awsIssuer, c.Issuer)
|
||||
assert.Equals(t, subject, c.Subject)
|
||||
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience)
|
||||
assert.Equals(t, aws.Accounts[0], c.document.AccountID)
|
||||
err = aws.config.certificate.CheckSignature(
|
||||
aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAWS_Init(t *testing.T) {
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
|
@ -203,6 +228,7 @@ func TestAWS_Init(t *testing.T) {
|
|||
DisableCustomSANs bool
|
||||
DisableTrustOnFirstUse bool
|
||||
InstanceAge Duration
|
||||
IMDSVersions []string
|
||||
Claims *Claims
|
||||
}
|
||||
type args struct {
|
||||
|
@ -214,12 +240,15 @@ func TestAWS_Init(t *testing.T) {
|
|||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, nil}, args{config}, false},
|
||||
{"ok", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, nil}, args{config}, false},
|
||||
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, nil}, args{config}, true},
|
||||
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, nil}, args{config}, true},
|
||||
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, nil}, args{config}, true},
|
||||
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, badClaims}, args{config}, true},
|
||||
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false},
|
||||
{"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false},
|
||||
{"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false},
|
||||
{"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false},
|
||||
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
|
||||
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
|
||||
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true},
|
||||
{"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true},
|
||||
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -230,6 +259,7 @@ func TestAWS_Init(t *testing.T) {
|
|||
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
||||
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
||||
InstanceAge: tt.fields.InstanceAge,
|
||||
IMDSVersions: tt.fields.IMDSVersions,
|
||||
Claims: tt.fields.Claims,
|
||||
}
|
||||
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||
|
|
|
@ -495,6 +495,101 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
|||
return aws, srv, nil
|
||||
}
|
||||
|
||||
func generateAWSV1Only() (*AWS, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accountID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(awsTestCertificate))
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return nil, errors.New("error decoding AWS certificate")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
||||
}
|
||||
return &AWS{
|
||||
Type: "AWS",
|
||||
Name: name,
|
||||
Accounts: []string{accountID},
|
||||
Claims: &globalProvisionerClaims,
|
||||
IMDSVersions: []string{"v1"},
|
||||
claimer: claimer,
|
||||
config: &awsConfig{
|
||||
identityURL: awsIdentityURL,
|
||||
signatureURL: awsSignatureURL,
|
||||
tokenURL: awsAPITokenURL,
|
||||
tokenTTL: awsAPITokenTTL,
|
||||
certificate: cert,
|
||||
signatureAlgorithm: awsSignatureAlgorithm,
|
||||
},
|
||||
audiences: testAudiences.WithFragment("aws/" + name),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
|
||||
aws, err := generateAWSV1Only()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(awsTestKey))
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return nil, nil, errors.New("error decoding AWS key")
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
|
||||
}
|
||||
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
|
||||
AccountID: aws.Accounts[0],
|
||||
Architecture: "x86_64",
|
||||
AvailabilityZone: "us-west-2b",
|
||||
ImageID: "image-id",
|
||||
InstanceID: "instance-id",
|
||||
InstanceType: "t2.micro",
|
||||
PendingTime: time.Now(),
|
||||
PrivateIP: "127.0.0.1",
|
||||
Region: "us-west-1",
|
||||
Version: "2017-09-30",
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(doc)
|
||||
signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error signing document")
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/latest/dynamic/instance-identity/document":
|
||||
w.Write(doc)
|
||||
case "/latest/dynamic/instance-identity/signature":
|
||||
w.Write([]byte(base64.StdEncoding.EncodeToString(signature)))
|
||||
case "/bad-document":
|
||||
w.Write([]byte("{}"))
|
||||
case "/bad-signature":
|
||||
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
|
||||
case "/bad-json":
|
||||
w.Write([]byte("{"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document"
|
||||
aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature"
|
||||
return aws, srv, nil
|
||||
}
|
||||
|
||||
func generateAzure() (*Azure, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue