51f16ee2e0
* Add constructor tests for the aws provisioner. * Add a test to make sure the "v1" logic continues to work. By and large, v2 is the way to go. However, there are some instances of things that specifically request metadata service version 1 and so this adds minimal coverage to make sure we don't accidentally break the path should anyone need to depend on the former logic.
795 lines
27 KiB
Go
795 lines
27 KiB
Go
package provisioner
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/smallstep/assert"
|
|
"github.com/smallstep/certificates/errs"
|
|
"github.com/smallstep/cli/jose"
|
|
)
|
|
|
|
func TestAWS_Getters(t *testing.T) {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
aud := "aws/" + p.Name
|
|
if got := p.GetID(); got != aud {
|
|
t.Errorf("AWS.GetID() = %v, want %v", got, aud)
|
|
}
|
|
if got := p.GetName(); got != p.Name {
|
|
t.Errorf("AWS.GetName() = %v, want %v", got, p.Name)
|
|
}
|
|
if got := p.GetType(); got != TypeAWS {
|
|
t.Errorf("AWS.GetType() = %v, want %v", got, TypeAWS)
|
|
}
|
|
kid, key, ok := p.GetEncryptedKey()
|
|
if kid != "" || key != "" || ok == true {
|
|
t.Errorf("AWS.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
|
kid, key, ok, "", "", false)
|
|
}
|
|
}
|
|
|
|
func TestAWS_GetTokenID(t *testing.T) {
|
|
p1, srv, err := generateAWSWithServer()
|
|
assert.FatalError(t, err)
|
|
defer srv.Close()
|
|
|
|
p2, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p2.Accounts = p1.Accounts
|
|
p2.config = p1.config
|
|
p2.DisableTrustOnFirstUse = true
|
|
|
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
_, claims, err := parseAWSToken(t1)
|
|
assert.FatalError(t, err)
|
|
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))
|
|
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
|
|
|
|
t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
sum = sha256.Sum256([]byte(t2))
|
|
w2 := strings.ToLower(hex.EncodeToString(sum[:]))
|
|
|
|
type args struct {
|
|
token string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
aws *AWS
|
|
args args
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{"ok", p1, args{t1}, w1, false},
|
|
{"ok no TOFU", p2, args{t2}, w2, false},
|
|
{"fail", p1, args{"bad-token"}, "", true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := tt.aws.GetTokenID(tt.args.token)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.GetTokenID() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if got != tt.want {
|
|
t.Errorf("AWS.GetTokenID() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_GetIdentityToken(t *testing.T) {
|
|
p1, srv, err := generateAWSWithServer()
|
|
assert.FatalError(t, err)
|
|
defer srv.Close()
|
|
|
|
p2, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p2.Accounts = p1.Accounts
|
|
p2.config.identityURL = srv.URL + "/bad-document"
|
|
p2.config.signatureURL = p1.config.signatureURL
|
|
p2.config.tokenURL = p1.config.tokenURL
|
|
|
|
p3, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p3.Accounts = p1.Accounts
|
|
p3.config.signatureURL = srv.URL
|
|
p3.config.identityURL = p1.config.identityURL
|
|
p3.config.tokenURL = p1.config.tokenURL
|
|
|
|
p4, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p4.Accounts = p1.Accounts
|
|
p4.config.signatureURL = srv.URL + "/bad-signature"
|
|
p4.config.identityURL = p1.config.identityURL
|
|
p4.config.tokenURL = p1.config.tokenURL
|
|
|
|
p5, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p5.Accounts = p1.Accounts
|
|
p5.config.identityURL = "https://1234.1234.1234.1234"
|
|
p5.config.signatureURL = p1.config.signatureURL
|
|
p5.config.tokenURL = p1.config.tokenURL
|
|
|
|
p6, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p6.Accounts = p1.Accounts
|
|
p6.config.identityURL = p1.config.identityURL
|
|
p6.config.signatureURL = "https://1234.1234.1234.1234"
|
|
p6.config.tokenURL = p1.config.tokenURL
|
|
|
|
p7, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p7.Accounts = p1.Accounts
|
|
p7.config.identityURL = srv.URL + "/bad-json"
|
|
p7.config.signatureURL = p1.config.signatureURL
|
|
p7.config.tokenURL = p1.config.tokenURL
|
|
|
|
caURL := "https://ca.smallstep.com"
|
|
u, err := url.Parse(caURL)
|
|
assert.FatalError(t, err)
|
|
|
|
type args struct {
|
|
subject string
|
|
caURL string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
aws *AWS
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", p1, args{"foo.local", caURL}, false},
|
|
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
|
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
|
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
|
{"fail signature", p4, args{"foo.local", caURL}, true},
|
|
{"fail read identityURL", p5, args{"foo.local", caURL}, true},
|
|
{"fail read signatureURL", p6, args{"foo.local", caURL}, true},
|
|
{"fail unmarshal identityURL", p7, args{"foo.local", caURL}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := tt.aws.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if tt.wantErr == false {
|
|
_, c, err := parseAWSToken(got)
|
|
if assert.NoError(t, err) {
|
|
assert.Equals(t, awsIssuer, c.Issuer)
|
|
assert.Equals(t, tt.args.subject, c.Subject)
|
|
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience)
|
|
assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID)
|
|
err = tt.aws.config.certificate.CheckSignature(
|
|
tt.aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
|
|
assert.NoError(t, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_GetIdentityTokenV1Only(t *testing.T) {
|
|
aws, srv, err := generateAWSWithServerV1Only()
|
|
assert.FatalError(t, err)
|
|
defer srv.Close()
|
|
|
|
subject := "foo.local"
|
|
caURL := "https://ca.smallstep.com"
|
|
u, err := url.Parse(caURL)
|
|
assert.Nil(t, err)
|
|
|
|
token, err := aws.GetIdentityToken(subject, caURL)
|
|
assert.Nil(t, err)
|
|
|
|
_, c, err := parseAWSToken(token)
|
|
if assert.NoError(t, err) {
|
|
assert.Equals(t, awsIssuer, c.Issuer)
|
|
assert.Equals(t, subject, c.Subject)
|
|
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience)
|
|
assert.Equals(t, aws.Accounts[0], c.document.AccountID)
|
|
err = aws.config.certificate.CheckSignature(
|
|
aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature)
|
|
assert.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
func TestAWS_Init(t *testing.T) {
|
|
config := Config{
|
|
Claims: globalProvisionerClaims,
|
|
}
|
|
badClaims := &Claims{
|
|
DefaultTLSDur: &Duration{0},
|
|
}
|
|
zero := Duration{Duration: 0}
|
|
|
|
type fields struct {
|
|
Type string
|
|
Name string
|
|
Accounts []string
|
|
DisableCustomSANs bool
|
|
DisableTrustOnFirstUse bool
|
|
InstanceAge Duration
|
|
IMDSVersions []string
|
|
Claims *Claims
|
|
}
|
|
type args struct {
|
|
config Config
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, false},
|
|
{"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, nil}, args{config}, false},
|
|
{"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, nil}, args{config}, false},
|
|
{"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, false},
|
|
{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
|
|
{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, nil}, args{config}, true},
|
|
{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, nil}, args{config}, true},
|
|
{"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, nil}, args{config}, true},
|
|
{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, badClaims}, args{config}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := &AWS{
|
|
Type: tt.fields.Type,
|
|
Name: tt.fields.Name,
|
|
Accounts: tt.fields.Accounts,
|
|
DisableCustomSANs: tt.fields.DisableCustomSANs,
|
|
DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
|
|
InstanceAge: tt.fields.InstanceAge,
|
|
IMDSVersions: tt.fields.IMDSVersions,
|
|
Claims: tt.fields.Claims,
|
|
}
|
|
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.Init() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_authorizeToken(t *testing.T) {
|
|
block, _ := pem.Decode([]byte(awsTestKey))
|
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
|
t.Fatal("error decoding AWS key")
|
|
}
|
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
assert.FatalError(t, err)
|
|
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
assert.FatalError(t, err)
|
|
|
|
type test struct {
|
|
p *AWS
|
|
token string
|
|
err error
|
|
code int
|
|
}
|
|
tests := map[string]func(*testing.T) test{
|
|
"fail/bad-token": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: "foo",
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; error parsing aws token"),
|
|
}
|
|
},
|
|
"fail/cannot-validate-sig": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), badKey)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; invalid aws token signature"),
|
|
}
|
|
},
|
|
"fail/empty-account-id": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), "", "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"),
|
|
}
|
|
},
|
|
"fail/empty-instance-id": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"),
|
|
}
|
|
},
|
|
"fail/empty-private-ip": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"),
|
|
}
|
|
},
|
|
"fail/empty-region": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; aws identity document region cannot be empty"),
|
|
}
|
|
},
|
|
"fail/invalid-token-issuer": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", "bad-issuer", p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; invalid aws token"),
|
|
}
|
|
},
|
|
"fail/invalid-audience": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, "bad-audience", p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"),
|
|
}
|
|
},
|
|
"fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p.DisableCustomSANs = true
|
|
tok, err := generateAWSToken(
|
|
"foo", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"),
|
|
}
|
|
},
|
|
"fail/invalid-account-id": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), "foo", "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"),
|
|
}
|
|
},
|
|
"fail/instance-age": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p.InstanceAge = Duration{1 * time.Minute}
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
code: http.StatusUnauthorized,
|
|
err: errors.New("aws.authorizeToken; aws identity document pendingTime is too old"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
p, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
tok, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
p: p,
|
|
token: tok,
|
|
}
|
|
},
|
|
}
|
|
for name, tt := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
tc := tt(t)
|
|
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
|
if assert.NotNil(t, tc.err) {
|
|
sc, ok := err.(errs.StatusCoder)
|
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
|
|
assert.Equals(t, claims.Subject, "instance-id")
|
|
assert.Equals(t, claims.Issuer, awsIssuer)
|
|
assert.NotNil(t, claims.Amazon)
|
|
|
|
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
|
|
assert.FatalError(t, err)
|
|
assert.Equals(t, claims.Audience[0], aud)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_AuthorizeSign(t *testing.T) {
|
|
p1, srv, err := generateAWSWithServer()
|
|
assert.FatalError(t, err)
|
|
defer srv.Close()
|
|
|
|
p2, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p2.Accounts = p1.Accounts
|
|
p2.config = p1.config
|
|
p2.DisableCustomSANs = true
|
|
p2.InstanceAge = Duration{1 * time.Minute}
|
|
|
|
p3, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p3.config = p1.config
|
|
|
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
t2, err := p2.GetIdentityToken("instance-id", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
assert.FatalError(t, err)
|
|
t3, err := p3.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
|
|
// Alternative common names with DisableCustomSANs = true
|
|
t2PrivateIP, err := p2.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
t2Hostname, err := p2.GetIdentityToken("ip-127-0-0-1.us-west-1.compute.internal", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
|
|
block, _ := pem.Decode([]byte(awsTestKey))
|
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
|
t.Fatal("error decoding AWS key")
|
|
}
|
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
assert.FatalError(t, err)
|
|
|
|
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
assert.FatalError(t, err)
|
|
|
|
t4, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failSubject, err := generateAWSToken(
|
|
"bad-subject", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failIssuer, err := generateAWSToken(
|
|
"instance-id", "bad-issuer", p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failAudience, err := generateAWSToken(
|
|
"instance-id", awsIssuer, "bad-audience", p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failAccount, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), "", "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failInstanceID, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "",
|
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failPrivateIP, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"", "us-west-1", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failRegion, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "", time.Now(), key)
|
|
assert.FatalError(t, err)
|
|
failExp, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now().Add(-360*time.Second), key)
|
|
assert.FatalError(t, err)
|
|
failNbf, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now().Add(360*time.Second), key)
|
|
assert.FatalError(t, err)
|
|
failKey, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now(), badKey)
|
|
assert.FatalError(t, err)
|
|
failInstanceAge, err := generateAWSToken(
|
|
"instance-id", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id",
|
|
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
|
|
assert.FatalError(t, err)
|
|
|
|
type args struct {
|
|
token, cn string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
aws *AWS
|
|
args args
|
|
wantLen int
|
|
code int
|
|
wantErr bool
|
|
}{
|
|
{"ok", p1, args{t1, "foo.local"}, 5, http.StatusOK, false},
|
|
{"ok", p2, args{t2, "instance-id"}, 9, http.StatusOK, false},
|
|
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 9, http.StatusOK, false},
|
|
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 9, http.StatusOK, false},
|
|
{"ok", p1, args{t4, "instance-id"}, 5, http.StatusOK, false},
|
|
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
|
|
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
|
|
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
|
|
{"fail issuer", p1, args{token: failIssuer}, 0, http.StatusUnauthorized, true},
|
|
{"fail audience", p1, args{token: failAudience}, 0, http.StatusUnauthorized, true},
|
|
{"fail account", p1, args{token: failAccount}, 0, http.StatusUnauthorized, true},
|
|
{"fail instanceID", p1, args{token: failInstanceID}, 0, http.StatusUnauthorized, true},
|
|
{"fail privateIP", p1, args{token: failPrivateIP}, 0, http.StatusUnauthorized, true},
|
|
{"fail region", p1, args{token: failRegion}, 0, http.StatusUnauthorized, true},
|
|
{"fail exp", p1, args{token: failExp}, 0, http.StatusUnauthorized, true},
|
|
{"fail nbf", p1, args{token: failNbf}, 0, http.StatusUnauthorized, true},
|
|
{"fail key", p1, args{token: failKey}, 0, http.StatusUnauthorized, true},
|
|
{"fail instance age", p2, args{token: failInstanceAge}, 0, http.StatusUnauthorized, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
|
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
} else if err != nil {
|
|
sc, ok := err.(errs.StatusCoder)
|
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
|
} else {
|
|
assert.Len(t, tt.wantLen, got)
|
|
for _, o := range got {
|
|
switch v := o.(type) {
|
|
case *provisionerExtensionOption:
|
|
assert.Equals(t, v.Type, int(TypeAWS))
|
|
assert.Equals(t, v.Name, tt.aws.GetName())
|
|
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
|
|
assert.Len(t, 2, v.KeyValuePairs)
|
|
case profileDefaultDuration:
|
|
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
|
|
case commonNameValidator:
|
|
assert.Equals(t, string(v), tt.args.cn)
|
|
case defaultPublicKeyValidator:
|
|
case *validityValidator:
|
|
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
|
|
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
|
|
case ipAddressesValidator:
|
|
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
|
|
case emailAddressesValidator:
|
|
assert.Equals(t, v, nil)
|
|
case urisValidator:
|
|
assert.Equals(t, v, nil)
|
|
case dnsNamesValidator:
|
|
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
|
|
default:
|
|
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|
tm, fn := mockNow()
|
|
defer fn()
|
|
|
|
p1, srv, err := generateAWSWithServer()
|
|
assert.FatalError(t, err)
|
|
p1.DisableCustomSANs = true
|
|
defer srv.Close()
|
|
|
|
p2, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p2.Accounts = p1.Accounts
|
|
p2.config = p1.config
|
|
p2.DisableCustomSANs = false
|
|
|
|
p3, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
// disable sshCA
|
|
disable := false
|
|
p3.Claims = &Claims{EnableSSHCA: &disable}
|
|
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
|
assert.FatalError(t, err)
|
|
|
|
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
|
|
t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
|
assert.FatalError(t, err)
|
|
|
|
key, err := generateJSONWebKey()
|
|
assert.FatalError(t, err)
|
|
|
|
signer, err := generateJSONWebKey()
|
|
assert.FatalError(t, err)
|
|
|
|
pub := key.Public().Key
|
|
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
assert.FatalError(t, err)
|
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
assert.FatalError(t, err)
|
|
|
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
|
expectedHostOptions := &SSHOptions{
|
|
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
|
}
|
|
expectedHostOptionsIP := &SSHOptions{
|
|
CertType: "host", Principals: []string{"127.0.0.1"},
|
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
|
}
|
|
expectedHostOptionsHostname := &SSHOptions{
|
|
CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"},
|
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
|
}
|
|
expectedCustomOptions := &SSHOptions{
|
|
CertType: "host", Principals: []string{"foo.local"},
|
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
|
}
|
|
|
|
type args struct {
|
|
token string
|
|
sshOpts SSHOptions
|
|
key interface{}
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
aws *AWS
|
|
args args
|
|
expected *SSHOptions
|
|
code int
|
|
wantErr bool
|
|
wantSignErr bool
|
|
}{
|
|
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
|
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
|
|
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
|
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
|
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false},
|
|
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false},
|
|
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
|
{"ok-custom", p2, args{t2, SSHOptions{Principals: []string{"foo.local"}}, pub}, expectedCustomOptions, http.StatusOK, false, false},
|
|
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
|
|
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
|
|
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
|
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
|
{"fail-sshCA-disabled", p3, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
|
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil {
|
|
sc, ok := err.(errs.StatusCoder)
|
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
|
assert.Nil(t, got)
|
|
} else if assert.NotNil(t, got) {
|
|
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
|
if (err != nil) != tt.wantSignErr {
|
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
|
} else {
|
|
if tt.wantSignErr {
|
|
assert.Nil(t, cert)
|
|
} else {
|
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWS_AuthorizeRenew(t *testing.T) {
|
|
p1, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
p2, err := generateAWS()
|
|
assert.FatalError(t, err)
|
|
|
|
// disable renewal
|
|
disable := true
|
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
|
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
|
assert.FatalError(t, err)
|
|
|
|
type args struct {
|
|
cert *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
aws *AWS
|
|
args args
|
|
code int
|
|
wantErr bool
|
|
}{
|
|
{"ok", p1, args{nil}, http.StatusOK, false},
|
|
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
|
} else if err != nil {
|
|
sc, ok := err.(errs.StatusCoder)
|
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
|
}
|
|
})
|
|
}
|
|
}
|