forked from TrueCloudLab/certificates
Allow custom common names in cloud identity provisioners.
This commit is contained in:
parent
0c3e0088cf
commit
900ab9cc12
10 changed files with 182 additions and 88 deletions
|
@ -171,7 +171,7 @@ func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||||
|
|
||||||
// GetIdentityToken retrieves the identity document and it's signature and
|
// GetIdentityToken retrieves the identity document and it's signature and
|
||||||
// generates a token with them.
|
// generates a token with them.
|
||||||
func (p *AWS) GetIdentityToken(caURL string) (string, error) {
|
func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
|
||||||
// Initialize the config if this method is used from the cli.
|
// Initialize the config if this method is used from the cli.
|
||||||
if err := p.assertConfig(); err != nil {
|
if err := p.assertConfig(); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -221,7 +221,7 @@ func (p *AWS) GetIdentityToken(caURL string) (string, error) {
|
||||||
payload := awsPayload{
|
payload := awsPayload{
|
||||||
Claims: jose.Claims{
|
Claims: jose.Claims{
|
||||||
Issuer: awsIssuer,
|
Issuer: awsIssuer,
|
||||||
Subject: idoc.InstanceID,
|
Subject: subject,
|
||||||
Audience: []string{audience},
|
Audience: []string{audience},
|
||||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
NotBefore: jose.NewNumericDate(now),
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
@ -273,8 +273,8 @@ func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
doc := payload.document
|
doc := payload.document
|
||||||
|
|
||||||
// Enforce default DNS and IP if configured.
|
// Enforce known CN and default DNS and IP if configured.
|
||||||
// By default we'll accept the SANs in the CSR.
|
// By default we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
|
@ -287,9 +287,9 @@ func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
commonNameValidator(doc.InstanceID),
|
commonNameValidator(payload.Claims.Subject),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID),
|
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
@ -388,19 +388,26 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
// more than a few minutes.
|
// more than a few minutes.
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
if err = payload.ValidateWithLeeway(jose.Expected{
|
if err = payload.ValidateWithLeeway(jose.Expected{
|
||||||
Issuer: awsIssuer,
|
Issuer: awsIssuer,
|
||||||
Subject: doc.InstanceID,
|
Time: now,
|
||||||
Time: now,
|
|
||||||
}, time.Minute); err != nil {
|
}, time.Minute); err != nil {
|
||||||
return nil, errors.Wrapf(err, "invalid token")
|
return nil, errors.Wrapf(err, "invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate audiences with the defaults
|
// validate audiences with the defaults
|
||||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
||||||
fmt.Println(payload.Audience, "vs", p.audiences.Sign)
|
|
||||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate subject, it has to be known if disableCustomSANs is enabled
|
||||||
|
if p.DisableCustomSANs {
|
||||||
|
if payload.Subject != doc.InstanceID &&
|
||||||
|
payload.Subject != doc.PrivateIP &&
|
||||||
|
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
||||||
|
return nil, errors.New("invalid token: invalid subject claim (sub)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validate accounts
|
// validate accounts
|
||||||
if len(p.Accounts) > 0 {
|
if len(p.Accounts) > 0 {
|
||||||
var found bool
|
var found bool
|
||||||
|
|
|
@ -48,14 +48,14 @@ func TestAWS_GetTokenID(t *testing.T) {
|
||||||
p2.config = p1.config
|
p2.config = p1.config
|
||||||
p2.DisableTrustOnFirstUse = true
|
p2.DisableTrustOnFirstUse = true
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
_, claims, err := parseAWSToken(t1)
|
_, claims, err := parseAWSToken(t1)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))
|
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))
|
||||||
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
|
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
|
||||||
|
|
||||||
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
|
t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
sum = sha256.Sum256([]byte(t2))
|
sum = sha256.Sum256([]byte(t2))
|
||||||
w2 := strings.ToLower(hex.EncodeToString(sum[:]))
|
w2 := strings.ToLower(hex.EncodeToString(sum[:]))
|
||||||
|
@ -111,12 +111,31 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
||||||
p4.config.signatureURL = srv.URL + "/bad-signature"
|
p4.config.signatureURL = srv.URL + "/bad-signature"
|
||||||
p4.config.identityURL = p1.config.identityURL
|
p4.config.identityURL = p1.config.identityURL
|
||||||
|
|
||||||
|
p5, err := generateAWS()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p5.Accounts = p1.Accounts
|
||||||
|
p5.config.identityURL = "https://1234.1234.1234.1234"
|
||||||
|
p5.config.signatureURL = p1.config.signatureURL
|
||||||
|
|
||||||
|
p6, err := generateAWS()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p6.Accounts = p1.Accounts
|
||||||
|
p6.config.identityURL = p1.config.identityURL
|
||||||
|
p6.config.signatureURL = "https://1234.1234.1234.1234"
|
||||||
|
|
||||||
|
p7, err := generateAWS()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p7.Accounts = p1.Accounts
|
||||||
|
p7.config.identityURL = srv.URL + "/bad-json"
|
||||||
|
p7.config.signatureURL = p1.config.signatureURL
|
||||||
|
|
||||||
caURL := "https://ca.smallstep.com"
|
caURL := "https://ca.smallstep.com"
|
||||||
u, err := url.Parse(caURL)
|
u, err := url.Parse(caURL)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
caURL string
|
subject string
|
||||||
|
caURL string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -124,15 +143,18 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
||||||
args args
|
args args
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{caURL}, false},
|
{"ok", p1, args{"foo.local", caURL}, false},
|
||||||
{"fail ca url", p1, args{"://ca.smallstep.com"}, true},
|
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
||||||
{"fail identityURL", p2, args{caURL}, true},
|
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
||||||
{"fail signatureURL", p3, args{caURL}, true},
|
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
||||||
{"fail signature", p4, args{caURL}, true},
|
{"fail signature", p4, args{"foo.local", caURL}, true},
|
||||||
|
{"fail read identityURL", p5, args{"foo.local", caURL}, true},
|
||||||
|
{"fail read signatureURL", p6, args{"foo.local", caURL}, true},
|
||||||
|
{"fail unmarshal identityURL", p7, args{"foo.local", caURL}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.aws.GetIdentityToken(tt.args.caURL)
|
got, err := tt.aws.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -141,7 +163,7 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
||||||
_, c, err := parseAWSToken(got)
|
_, c, err := parseAWSToken(got)
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
assert.Equals(t, awsIssuer, c.Issuer)
|
assert.Equals(t, awsIssuer, c.Issuer)
|
||||||
assert.Equals(t, c.document.InstanceID, c.Subject)
|
assert.Equals(t, tt.args.subject, c.Subject)
|
||||||
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience)
|
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience)
|
||||||
assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID)
|
assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID)
|
||||||
err = tt.aws.config.certificate.CheckSignature(
|
err = tt.aws.config.certificate.CheckSignature(
|
||||||
|
@ -221,12 +243,18 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p3.config = p1.config
|
p3.config = p1.config
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
|
t2, err := p2.GetIdentityToken("instance-id", "https://ca.smallstep.com")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t3, err := p3.GetIdentityToken("https://ca.smallstep.com")
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
t3, err := p3.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Alternative common names with DisableCustomSANs = true
|
||||||
|
t2PrivateIP, err := p2.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t2Hostname, err := p2.GetIdentityToken("ip-127-0-0-1.us-west-1.compute.internal", "https://ca.smallstep.com")
|
||||||
|
|
||||||
block, _ := pem.Decode([]byte(awsTestKey))
|
block, _ := pem.Decode([]byte(awsTestKey))
|
||||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||||
|
@ -243,7 +271,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failSubject, err := generateAWSToken(
|
failSubject, err := generateAWSToken(
|
||||||
"bad-subject", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
"bad-subject", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id",
|
||||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failIssuer, err := generateAWSToken(
|
failIssuer, err := generateAWSToken(
|
||||||
|
@ -299,6 +327,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, 4, false},
|
{"ok", p1, args{t1}, 4, false},
|
||||||
{"ok", p2, args{t2}, 6, false},
|
{"ok", p2, args{t2}, 6, false},
|
||||||
|
{"ok", p2, args{t2Hostname}, 6, false},
|
||||||
|
{"ok", p2, args{t2PrivateIP}, 6, false},
|
||||||
{"ok", p1, args{t4}, 4, false},
|
{"ok", p1, args{t4}, 4, false},
|
||||||
{"fail account", p3, args{t3}, 0, true},
|
{"fail account", p3, args{t3}, 0, true},
|
||||||
{"fail token", p1, args{"token"}, 0, true},
|
{"fail token", p1, args{"token"}, 0, true},
|
||||||
|
@ -364,7 +394,7 @@ func TestAWS_AuthorizeRevoke(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
|
|
@ -141,7 +141,7 @@ func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||||
|
|
||||||
// GetIdentityToken retrieves from the metadata service the identity token and
|
// GetIdentityToken retrieves from the metadata service the identity token and
|
||||||
// returns it.
|
// returns it.
|
||||||
func (p *Azure) GetIdentityToken() (string, error) {
|
func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
|
||||||
// Initialize the config if this method is used from the cli.
|
// Initialize the config if this method is used from the cli.
|
||||||
p.assertConfig()
|
p.assertConfig()
|
||||||
|
|
||||||
|
@ -264,17 +264,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce default DNS if configured.
|
// Enforce known common name and default DNS if configured.
|
||||||
// By default we'll accept the SANs in the CSR.
|
// By default we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
// name will work only inside the virtual network
|
// name will work only inside the virtual network
|
||||||
|
so = append(so, commonNameValidator(name))
|
||||||
so = append(so, dnsNamesValidator([]string{name}))
|
so = append(so, dnsNamesValidator([]string{name}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
commonNameValidator(name),
|
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||||
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||||
|
|
|
@ -46,9 +46,9 @@ func TestAzure_GetTokenID(t *testing.T) {
|
||||||
p2.keyStore = p1.keyStore
|
p2.keyStore = p1.keyStore
|
||||||
p2.DisableTrustOnFirstUse = true
|
p2.DisableTrustOnFirstUse = true
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken()
|
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t2, err := p2.GetIdentityToken()
|
t2, err := p2.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine"))
|
sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine"))
|
||||||
|
@ -105,23 +105,28 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
subject string
|
||||||
|
caURL string
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
azure *Azure
|
azure *Azure
|
||||||
|
args args
|
||||||
identityTokenURL string
|
identityTokenURL string
|
||||||
want string
|
want string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, srv.URL, t1, false},
|
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
|
||||||
{"fail request", p1, srv.URL + "/bad-request", "", true},
|
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
|
||||||
{"fail unmarshal", p1, srv.URL + "/bad-json", "", true},
|
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
|
||||||
{"fail url", p1, "://ca.smallstep.com", "", true},
|
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
|
||||||
{"fail connect", p1, "foobarzar", "", true},
|
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
||||||
got, err := tt.azure.GetIdentityToken()
|
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -231,13 +236,13 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
badKey, err := generateJSONWebKey()
|
badKey, err := generateJSONWebKey()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken()
|
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t2, err := p2.GetIdentityToken()
|
t2, err := p2.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t3, err := p3.GetIdentityToken()
|
t3, err := p3.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t4, err := p4.GetIdentityToken()
|
t4, err := p4.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
|
@ -276,9 +281,9 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
wantLen int
|
wantLen int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, 4, false},
|
{"ok", p1, args{t1}, 3, false},
|
||||||
{"ok", p2, args{t2}, 5, false},
|
{"ok", p2, args{t2}, 5, false},
|
||||||
{"ok", p1, args{t11}, 4, false},
|
{"ok", p1, args{t11}, 3, false},
|
||||||
{"fail tenant", p3, args{t3}, 0, true},
|
{"fail tenant", p3, args{t3}, 0, true},
|
||||||
{"fail resource group", p4, args{t4}, 0, true},
|
{"fail resource group", p4, args{t4}, 0, true},
|
||||||
{"fail token", p1, args{"token"}, 0, true},
|
{"fail token", p1, args{"token"}, 0, true},
|
||||||
|
@ -338,7 +343,7 @@ func TestAzure_AuthorizeRevoke(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
token, err := az.GetIdentityToken()
|
token, err := az.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
|
|
@ -150,7 +150,7 @@ func (p *GCP) GetIdentityURL(audience string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentityToken does an HTTP request to the identity url.
|
// GetIdentityToken does an HTTP request to the identity url.
|
||||||
func (p *GCP) GetIdentityToken(caURL string) (string, error) {
|
func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
|
||||||
audience, err := generateSignAudience(caURL, p.GetID())
|
audience, err := generateSignAudience(caURL, p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -212,21 +212,24 @@ func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
ce := claims.Google.ComputeEngine
|
ce := claims.Google.ComputeEngine
|
||||||
|
|
||||||
// Enforce default DNS if configured.
|
// Enforce known common name and default DNS if configured.
|
||||||
// By default we we'll accept the SANs in the CSR.
|
// By default we we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
|
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
||||||
|
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
||||||
|
so = append(so, commonNameSliceValidator([]string{
|
||||||
|
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||||
|
}))
|
||||||
so = append(so, dnsNamesValidator([]string{
|
so = append(so, dnsNamesValidator([]string{
|
||||||
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
|
dnsName1, dnsName2,
|
||||||
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
commonNameValidator(ce.InstanceName),
|
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject),
|
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,7 +117,8 @@ func TestGCP_GetIdentityToken(t *testing.T) {
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
caURL string
|
subject string
|
||||||
|
caURL string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -127,16 +128,16 @@ func TestGCP_GetIdentityToken(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{"https://ca"}, srv.URL, t1, false},
|
{"ok", p1, args{"subject", "https://ca"}, srv.URL, t1, false},
|
||||||
{"fail ca url", p1, args{"://ca"}, srv.URL, "", true},
|
{"fail ca url", p1, args{"subject", "://ca"}, srv.URL, "", true},
|
||||||
{"fail request", p1, args{"https://ca"}, srv.URL + "/bad-request", "", true},
|
{"fail request", p1, args{"subject", "https://ca"}, srv.URL + "/bad-request", "", true},
|
||||||
{"fail url", p1, args{"https://ca"}, "://ca.smallstep.com", "", true},
|
{"fail url", p1, args{"subject", "https://ca"}, "://ca.smallstep.com", "", true},
|
||||||
{"fail connect", p1, args{"https://ca"}, "foobarzar", "", true},
|
{"fail connect", p1, args{"subject", "https://ca"}, "foobarzar", "", true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.gcp.config.IdentityURL = tt.identityURL
|
tt.gcp.config.IdentityURL = tt.identityURL
|
||||||
got, err := tt.gcp.GetIdentityToken(tt.args.caURL)
|
got, err := tt.gcp.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -310,9 +311,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
wantLen int
|
wantLen int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, 4, false},
|
{"ok", p1, args{t1}, 3, false},
|
||||||
{"ok", p2, args{t2}, 5, false},
|
{"ok", p2, args{t2}, 5, false},
|
||||||
{"ok", p3, args{t3}, 4, false},
|
{"ok", p3, args{t3}, 3, false},
|
||||||
{"fail token", p1, args{"token"}, 0, true},
|
{"fail token", p1, args{"token"}, 0, true},
|
||||||
{"fail key", p1, args{failKey}, 0, true},
|
{"fail key", p1, args{failKey}, 0, true},
|
||||||
{"fail iss", p1, args{failIss}, 0, true},
|
{"fail iss", p1, args{failIss}, 0, true},
|
||||||
|
|
|
@ -55,7 +55,12 @@ func (v profileWithOption) Option(Options) x509util.WithOption {
|
||||||
type profileDefaultDuration time.Duration
|
type profileDefaultDuration time.Duration
|
||||||
|
|
||||||
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
||||||
return x509util.WithNotBeforeAfterDuration(so.NotBefore.Time(), so.NotAfter.Time(), time.Duration(v))
|
notBefore := so.NotBefore.Time()
|
||||||
|
if notBefore.IsZero() {
|
||||||
|
notBefore = time.Now()
|
||||||
|
}
|
||||||
|
notAfter := so.NotAfter.RelativeTime(notBefore)
|
||||||
|
return x509util.WithNotBeforeAfterDuration(notBefore, notAfter, time.Duration(v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||||
|
@ -97,6 +102,21 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// commonNameSliceValidator validates thats the common name of a certificate request is present in the slice.
|
||||||
|
type commonNameSliceValidator []string
|
||||||
|
|
||||||
|
func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
|
if req.Subject.CommonName == "" {
|
||||||
|
return errors.New("certificate request cannot contain an empty common name")
|
||||||
|
}
|
||||||
|
for _, cn := range v {
|
||||||
|
if req.Subject.CommonName == cn {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||||
|
}
|
||||||
|
|
||||||
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||||
type dnsNamesValidator []string
|
type dnsNamesValidator []string
|
||||||
|
|
||||||
|
@ -180,29 +200,32 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type stepProvisionerASN1 struct {
|
type stepProvisionerASN1 struct {
|
||||||
Type int
|
Type int
|
||||||
Name []byte
|
Name []byte
|
||||||
CredentialID []byte
|
CredentialID []byte
|
||||||
|
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type provisionerExtensionOption struct {
|
type provisionerExtensionOption struct {
|
||||||
Type int
|
Type int
|
||||||
Name string
|
Name string
|
||||||
CredentialID string
|
CredentialID string
|
||||||
|
KeyValuePairs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
|
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
|
||||||
return &provisionerExtensionOption{
|
return &provisionerExtensionOption{
|
||||||
Type: int(typ),
|
Type: int(typ),
|
||||||
Name: name,
|
Name: name,
|
||||||
CredentialID: credentialID,
|
CredentialID: credentialID,
|
||||||
|
KeyValuePairs: keyValuePairs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||||
return func(p x509util.Profile) error {
|
return func(p x509util.Profile) error {
|
||||||
crt := p.Subject()
|
crt := p.Subject()
|
||||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -211,11 +234,12 @@ func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
|
||||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Name: []byte(name),
|
Name: []byte(name),
|
||||||
CredentialID: []byte(credentialID),
|
CredentialID: []byte(credentialID),
|
||||||
|
KeyValuePairs: keyValuePairs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||||
|
|
|
@ -64,6 +64,30 @@ func Test_commonNameValidator_Valid(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_commonNameSliceValidator_Valid(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
req *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v commonNameSliceValidator
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||||
|
{"ok", []string{"example.com", "foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||||
|
{"empty", []string{""}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
|
||||||
|
{"wrong", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("commonNameSliceValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
req *x509.CertificateRequest
|
req *x509.CertificateRequest
|
||||||
|
|
|
@ -104,6 +104,12 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error {
|
||||||
|
|
||||||
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
|
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
|
||||||
func (t *TimeDuration) Time() time.Time {
|
func (t *TimeDuration) Time() time.Time {
|
||||||
|
return t.RelativeTime(now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelativeTime returns the embedded time.Time or the base time plus the
|
||||||
|
// duration if this is not zero.
|
||||||
|
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
|
||||||
switch {
|
switch {
|
||||||
case t == nil:
|
case t == nil:
|
||||||
return time.Time{}
|
return time.Time{}
|
||||||
|
@ -111,8 +117,8 @@ func (t *TimeDuration) Time() time.Time {
|
||||||
if t.d == 0 {
|
if t.d == 0 {
|
||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
t.t = now().Add(t.d)
|
t.t = base.Add(t.d)
|
||||||
return t.t
|
return t.t.UTC()
|
||||||
default:
|
default:
|
||||||
return t.t.UTC()
|
return t.t.UTC()
|
||||||
}
|
}
|
||||||
|
|
|
@ -283,20 +283,12 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
|
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
|
||||||
}
|
}
|
||||||
instanceID, err := randutil.Alphanumeric(10)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
imageID, err := randutil.Alphanumeric(10)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
|
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
|
||||||
AccountID: aws.Accounts[0],
|
AccountID: aws.Accounts[0],
|
||||||
Architecture: "x86_64",
|
Architecture: "x86_64",
|
||||||
AvailabilityZone: "us-west-2b",
|
AvailabilityZone: "us-west-2b",
|
||||||
ImageID: imageID,
|
ImageID: "image-id",
|
||||||
InstanceID: instanceID,
|
InstanceID: "instance-id",
|
||||||
InstanceType: "t2.micro",
|
InstanceType: "t2.micro",
|
||||||
PendingTime: time.Now(),
|
PendingTime: time.Now(),
|
||||||
PrivateIP: "127.0.0.1",
|
PrivateIP: "127.0.0.1",
|
||||||
|
@ -322,6 +314,8 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
||||||
w.Write([]byte("{}"))
|
w.Write([]byte("{}"))
|
||||||
case "/bad-signature":
|
case "/bad-signature":
|
||||||
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
|
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
|
||||||
|
case "/bad-json":
|
||||||
|
w.Write([]byte("{"))
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue