forked from TrueCloudLab/certificates
Allow custom common names in cloud identity provisioners.
This commit is contained in:
parent
0c3e0088cf
commit
900ab9cc12
10 changed files with 182 additions and 88 deletions
|
@ -171,7 +171,7 @@ func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
|
|||
|
||||
// GetIdentityToken retrieves the identity document and it's signature and
|
||||
// generates a token with them.
|
||||
func (p *AWS) GetIdentityToken(caURL string) (string, error) {
|
||||
func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
|
||||
// Initialize the config if this method is used from the cli.
|
||||
if err := p.assertConfig(); err != nil {
|
||||
return "", err
|
||||
|
@ -221,7 +221,7 @@ func (p *AWS) GetIdentityToken(caURL string) (string, error) {
|
|||
payload := awsPayload{
|
||||
Claims: jose.Claims{
|
||||
Issuer: awsIssuer,
|
||||
Subject: idoc.InstanceID,
|
||||
Subject: subject,
|
||||
Audience: []string{audience},
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
|
@ -273,8 +273,8 @@ func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
doc := payload.document
|
||||
|
||||
// Enforce default DNS and IP if configured.
|
||||
// By default we'll accept the SANs in the CSR.
|
||||
// Enforce known CN and default DNS and IP if configured.
|
||||
// By default we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
var so []SignOption
|
||||
if p.DisableCustomSANs {
|
||||
|
@ -287,9 +287,9 @@ func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
|
||||
return append(so,
|
||||
commonNameValidator(doc.InstanceID),
|
||||
commonNameValidator(payload.Claims.Subject),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID),
|
||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
), nil
|
||||
}
|
||||
|
@ -388,19 +388,26 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
// more than a few minutes.
|
||||
now := time.Now().UTC()
|
||||
if err = payload.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: awsIssuer,
|
||||
Subject: doc.InstanceID,
|
||||
Time: now,
|
||||
Issuer: awsIssuer,
|
||||
Time: now,
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
||||
fmt.Println(payload.Audience, "vs", p.audiences.Sign)
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// Validate subject, it has to be known if disableCustomSANs is enabled
|
||||
if p.DisableCustomSANs {
|
||||
if payload.Subject != doc.InstanceID &&
|
||||
payload.Subject != doc.PrivateIP &&
|
||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
||||
return nil, errors.New("invalid token: invalid subject claim (sub)")
|
||||
}
|
||||
}
|
||||
|
||||
// validate accounts
|
||||
if len(p.Accounts) > 0 {
|
||||
var found bool
|
||||
|
|
|
@ -48,14 +48,14 @@ func TestAWS_GetTokenID(t *testing.T) {
|
|||
p2.config = p1.config
|
||||
p2.DisableTrustOnFirstUse = true
|
||||
|
||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
_, claims, err := parseAWSToken(t1)
|
||||
assert.FatalError(t, err)
|
||||
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))
|
||||
w1 := strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
|
||||
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
|
||||
t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
sum = sha256.Sum256([]byte(t2))
|
||||
w2 := strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
|
@ -111,12 +111,31 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
p4.config.signatureURL = srv.URL + "/bad-signature"
|
||||
p4.config.identityURL = p1.config.identityURL
|
||||
|
||||
p5, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p5.Accounts = p1.Accounts
|
||||
p5.config.identityURL = "https://1234.1234.1234.1234"
|
||||
p5.config.signatureURL = p1.config.signatureURL
|
||||
|
||||
p6, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p6.Accounts = p1.Accounts
|
||||
p6.config.identityURL = p1.config.identityURL
|
||||
p6.config.signatureURL = "https://1234.1234.1234.1234"
|
||||
|
||||
p7, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p7.Accounts = p1.Accounts
|
||||
p7.config.identityURL = srv.URL + "/bad-json"
|
||||
p7.config.signatureURL = p1.config.signatureURL
|
||||
|
||||
caURL := "https://ca.smallstep.com"
|
||||
u, err := url.Parse(caURL)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
caURL string
|
||||
subject string
|
||||
caURL string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -124,15 +143,18 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{caURL}, false},
|
||||
{"fail ca url", p1, args{"://ca.smallstep.com"}, true},
|
||||
{"fail identityURL", p2, args{caURL}, true},
|
||||
{"fail signatureURL", p3, args{caURL}, true},
|
||||
{"fail signature", p4, args{caURL}, true},
|
||||
{"ok", p1, args{"foo.local", caURL}, false},
|
||||
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
||||
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
||||
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
||||
{"fail signature", p4, args{"foo.local", caURL}, true},
|
||||
{"fail read identityURL", p5, args{"foo.local", caURL}, true},
|
||||
{"fail read signatureURL", p6, args{"foo.local", caURL}, true},
|
||||
{"fail unmarshal identityURL", p7, args{"foo.local", caURL}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.aws.GetIdentityToken(tt.args.caURL)
|
||||
got, err := tt.aws.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -141,7 +163,7 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
_, c, err := parseAWSToken(got)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equals(t, awsIssuer, c.Issuer)
|
||||
assert.Equals(t, c.document.InstanceID, c.Subject)
|
||||
assert.Equals(t, tt.args.subject, c.Subject)
|
||||
assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience)
|
||||
assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID)
|
||||
err = tt.aws.config.certificate.CheckSignature(
|
||||
|
@ -221,12 +243,18 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
p3.config = p1.config
|
||||
|
||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
t2, err := p2.GetIdentityToken("https://ca.smallstep.com")
|
||||
t2, err := p2.GetIdentityToken("instance-id", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
t3, err := p3.GetIdentityToken("https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
t3, err := p3.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Alternative common names with DisableCustomSANs = true
|
||||
t2PrivateIP, err := p2.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
t2Hostname, err := p2.GetIdentityToken("ip-127-0-0-1.us-west-1.compute.internal", "https://ca.smallstep.com")
|
||||
|
||||
block, _ := pem.Decode([]byte(awsTestKey))
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
|
@ -243,7 +271,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
failSubject, err := generateAWSToken(
|
||||
"bad-subject", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id",
|
||||
"bad-subject", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
failIssuer, err := generateAWSToken(
|
||||
|
@ -299,6 +327,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
}{
|
||||
{"ok", p1, args{t1}, 4, false},
|
||||
{"ok", p2, args{t2}, 6, false},
|
||||
{"ok", p2, args{t2Hostname}, 6, false},
|
||||
{"ok", p2, args{t2PrivateIP}, 6, false},
|
||||
{"ok", p1, args{t4}, 4, false},
|
||||
{"fail account", p3, args{t3}, 0, true},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
|
@ -364,7 +394,7 @@ func TestAWS_AuthorizeRevoke(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
t1, err := p1.GetIdentityToken("https://ca.smallstep.com")
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
|
|
@ -141,7 +141,7 @@ func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
|
|||
|
||||
// GetIdentityToken retrieves from the metadata service the identity token and
|
||||
// returns it.
|
||||
func (p *Azure) GetIdentityToken() (string, error) {
|
||||
func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
|
||||
// Initialize the config if this method is used from the cli.
|
||||
p.assertConfig()
|
||||
|
||||
|
@ -264,17 +264,17 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Enforce default DNS if configured.
|
||||
// By default we'll accept the SANs in the CSR.
|
||||
// Enforce known common name and default DNS if configured.
|
||||
// By default we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
var so []SignOption
|
||||
if p.DisableCustomSANs {
|
||||
// name will work only inside the virtual network
|
||||
so = append(so, commonNameValidator(name))
|
||||
so = append(so, dnsNamesValidator([]string{name}))
|
||||
}
|
||||
|
||||
return append(so,
|
||||
commonNameValidator(name),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
|
|
|
@ -46,9 +46,9 @@ func TestAzure_GetTokenID(t *testing.T) {
|
|||
p2.keyStore = p1.keyStore
|
||||
p2.DisableTrustOnFirstUse = true
|
||||
|
||||
t1, err := p1.GetIdentityToken()
|
||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
t2, err := p2.GetIdentityToken()
|
||||
t2, err := p2.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine"))
|
||||
|
@ -105,23 +105,28 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
}))
|
||||
defer srv.Close()
|
||||
|
||||
type args struct {
|
||||
subject string
|
||||
caURL string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
identityTokenURL string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, srv.URL, t1, false},
|
||||
{"fail request", p1, srv.URL + "/bad-request", "", true},
|
||||
{"fail unmarshal", p1, srv.URL + "/bad-json", "", true},
|
||||
{"fail url", p1, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, "foobarzar", "", true},
|
||||
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
|
||||
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
|
||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
|
||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
||||
got, err := tt.azure.GetIdentityToken()
|
||||
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -231,13 +236,13 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
badKey, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := p1.GetIdentityToken()
|
||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
t2, err := p2.GetIdentityToken()
|
||||
t2, err := p2.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
t3, err := p3.GetIdentityToken()
|
||||
t3, err := p3.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
t4, err := p4.GetIdentityToken()
|
||||
t4, err := p4.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
|
@ -276,9 +281,9 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 4, false},
|
||||
{"ok", p1, args{t1}, 3, false},
|
||||
{"ok", p2, args{t2}, 5, false},
|
||||
{"ok", p1, args{t11}, 4, false},
|
||||
{"ok", p1, args{t11}, 3, false},
|
||||
{"fail tenant", p3, args{t3}, 0, true},
|
||||
{"fail resource group", p4, args{t4}, 0, true},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
|
@ -338,7 +343,7 @@ func TestAzure_AuthorizeRevoke(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
token, err := az.GetIdentityToken()
|
||||
token, err := az.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
|
|
@ -150,7 +150,7 @@ func (p *GCP) GetIdentityURL(audience string) string {
|
|||
}
|
||||
|
||||
// GetIdentityToken does an HTTP request to the identity url.
|
||||
func (p *GCP) GetIdentityToken(caURL string) (string, error) {
|
||||
func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
|
||||
audience, err := generateSignAudience(caURL, p.GetID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -212,21 +212,24 @@ func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
ce := claims.Google.ComputeEngine
|
||||
|
||||
// Enforce default DNS if configured.
|
||||
// By default we we'll accept the SANs in the CSR.
|
||||
// Enforce known common name and default DNS if configured.
|
||||
// By default we we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
var so []SignOption
|
||||
if p.DisableCustomSANs {
|
||||
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
||||
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
||||
so = append(so, commonNameSliceValidator([]string{
|
||||
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||
}))
|
||||
so = append(so, dnsNamesValidator([]string{
|
||||
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
|
||||
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
|
||||
dnsName1, dnsName2,
|
||||
}))
|
||||
}
|
||||
|
||||
return append(so,
|
||||
commonNameValidator(ce.InstanceName),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject),
|
||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -117,7 +117,8 @@ func TestGCP_GetIdentityToken(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
type args struct {
|
||||
caURL string
|
||||
subject string
|
||||
caURL string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -127,16 +128,16 @@ func TestGCP_GetIdentityToken(t *testing.T) {
|
|||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{"https://ca"}, srv.URL, t1, false},
|
||||
{"fail ca url", p1, args{"://ca"}, srv.URL, "", true},
|
||||
{"fail request", p1, args{"https://ca"}, srv.URL + "/bad-request", "", true},
|
||||
{"fail url", p1, args{"https://ca"}, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, args{"https://ca"}, "foobarzar", "", true},
|
||||
{"ok", p1, args{"subject", "https://ca"}, srv.URL, t1, false},
|
||||
{"fail ca url", p1, args{"subject", "://ca"}, srv.URL, "", true},
|
||||
{"fail request", p1, args{"subject", "https://ca"}, srv.URL + "/bad-request", "", true},
|
||||
{"fail url", p1, args{"subject", "https://ca"}, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, args{"subject", "https://ca"}, "foobarzar", "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.gcp.config.IdentityURL = tt.identityURL
|
||||
got, err := tt.gcp.GetIdentityToken(tt.args.caURL)
|
||||
got, err := tt.gcp.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||
t.Log(err)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -310,9 +311,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 4, false},
|
||||
{"ok", p1, args{t1}, 3, false},
|
||||
{"ok", p2, args{t2}, 5, false},
|
||||
{"ok", p3, args{t3}, 4, false},
|
||||
{"ok", p3, args{t3}, 3, false},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
{"fail key", p1, args{failKey}, 0, true},
|
||||
{"fail iss", p1, args{failIss}, 0, true},
|
||||
|
|
|
@ -55,7 +55,12 @@ func (v profileWithOption) Option(Options) x509util.WithOption {
|
|||
type profileDefaultDuration time.Duration
|
||||
|
||||
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
||||
return x509util.WithNotBeforeAfterDuration(so.NotBefore.Time(), so.NotAfter.Time(), time.Duration(v))
|
||||
notBefore := so.NotBefore.Time()
|
||||
if notBefore.IsZero() {
|
||||
notBefore = time.Now()
|
||||
}
|
||||
notAfter := so.NotAfter.RelativeTime(notBefore)
|
||||
return x509util.WithNotBeforeAfterDuration(notBefore, notAfter, time.Duration(v))
|
||||
}
|
||||
|
||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||
|
@ -97,6 +102,21 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// commonNameSliceValidator validates thats the common name of a certificate request is present in the slice.
|
||||
type commonNameSliceValidator []string
|
||||
|
||||
func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error {
|
||||
if req.Subject.CommonName == "" {
|
||||
return errors.New("certificate request cannot contain an empty common name")
|
||||
}
|
||||
for _, cn := range v {
|
||||
if req.Subject.CommonName == cn {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||
}
|
||||
|
||||
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||
type dnsNamesValidator []string
|
||||
|
||||
|
@ -180,29 +200,32 @@ var (
|
|||
)
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||
}
|
||||
|
||||
type provisionerExtensionOption struct {
|
||||
Type int
|
||||
Name string
|
||||
CredentialID string
|
||||
Type int
|
||||
Name string
|
||||
CredentialID string
|
||||
KeyValuePairs []string
|
||||
}
|
||||
|
||||
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
|
||||
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
|
||||
return &provisionerExtensionOption{
|
||||
Type: int(typ),
|
||||
Name: name,
|
||||
CredentialID: credentialID,
|
||||
Type: int(typ),
|
||||
Name: name,
|
||||
CredentialID: credentialID,
|
||||
KeyValuePairs: keyValuePairs,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -211,11 +234,12 @@ func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
|||
}
|
||||
}
|
||||
|
||||
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
||||
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
KeyValuePairs: keyValuePairs,
|
||||
})
|
||||
if err != nil {
|
||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||
|
|
|
@ -64,6 +64,30 @@ func Test_commonNameValidator_Valid(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_commonNameSliceValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v commonNameSliceValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||
{"ok", []string{"example.com", "foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||
{"empty", []string{""}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
|
||||
{"wrong", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("commonNameSliceValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
|
|
|
@ -104,6 +104,12 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error {
|
|||
|
||||
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
|
||||
func (t *TimeDuration) Time() time.Time {
|
||||
return t.RelativeTime(now())
|
||||
}
|
||||
|
||||
// RelativeTime returns the embedded time.Time or the base time plus the
|
||||
// duration if this is not zero.
|
||||
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
|
||||
switch {
|
||||
case t == nil:
|
||||
return time.Time{}
|
||||
|
@ -111,8 +117,8 @@ func (t *TimeDuration) Time() time.Time {
|
|||
if t.d == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
t.t = now().Add(t.d)
|
||||
return t.t
|
||||
t.t = base.Add(t.d)
|
||||
return t.t.UTC()
|
||||
default:
|
||||
return t.t.UTC()
|
||||
}
|
||||
|
|
|
@ -283,20 +283,12 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
|||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing AWS private key")
|
||||
}
|
||||
instanceID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
imageID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
|
||||
AccountID: aws.Accounts[0],
|
||||
Architecture: "x86_64",
|
||||
AvailabilityZone: "us-west-2b",
|
||||
ImageID: imageID,
|
||||
InstanceID: instanceID,
|
||||
ImageID: "image-id",
|
||||
InstanceID: "instance-id",
|
||||
InstanceType: "t2.micro",
|
||||
PendingTime: time.Now(),
|
||||
PrivateIP: "127.0.0.1",
|
||||
|
@ -322,6 +314,8 @@ func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
|||
w.Write([]byte("{}"))
|
||||
case "/bad-signature":
|
||||
w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
|
||||
case "/bad-json":
|
||||
w.Write([]byte("{"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue