Refactor claims so they can be totally omitted if only the parent is set.

This commit is contained in:
Mariano Cano 2019-03-19 15:10:52 -07:00
parent 095ab891e7
commit 7378ed27ac
7 changed files with 85 additions and 54 deletions

View file

@ -60,7 +60,6 @@ type AuthConfig struct {
// Validate validates the authority configuration. // Validate validates the authority configuration.
func (c *AuthConfig) Validate(audiences []string) error { func (c *AuthConfig) Validate(audiences []string) error {
var err error
if c == nil { if c == nil {
return errors.New("authority cannot be undefined") return errors.New("authority cannot be undefined")
} }
@ -68,13 +67,15 @@ func (c *AuthConfig) Validate(audiences []string) error {
return errors.New("authority.provisioners cannot be empty") return errors.New("authority.provisioners cannot be empty")
} }
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil { // Merge global and configuration claims
claimer, err := provisioner.NewClaimer(c.Claims, globalProvisionerClaims)
if err != nil {
return err return err
} }
// Initialize provisioners // Initialize provisioners
config := provisioner.Config{ config := provisioner.Config{
Claims: *c.Claims, Claims: claimer.Claims(),
Audiences: audiences, Audiences: audiences,
} }
for _, p := range c.Provisioners { for _, p := range c.Provisioners {

View file

@ -8,76 +8,90 @@ import (
// Claims so that individual provisioners can override global claims. // Claims so that individual provisioners can override global claims.
type Claims struct { type Claims struct {
globalClaims *Claims
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"`
} }
// Init initializes and validates the individual provisioner claims. // Claimer is the type that controls claims. It provides an interface around the
func (pc *Claims) Init(global *Claims) (*Claims, error) { // current claim and the global one.
if pc == nil { type Claimer struct {
pc = &Claims{} global Claims
claims *Claims
}
// NewClaimer initializes a new claimer with the given claims.
func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
c := &Claimer{global: global, claims: claims}
return c, c.Validate()
}
// Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal()
return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal,
} }
pc.globalClaims = global
return pc, pc.Validate()
} }
// DefaultTLSCertDuration returns the default TLS cert duration for the // DefaultTLSCertDuration returns the default TLS cert duration for the
// provisioner. If the default is not set within the provisioner, then the global // provisioner. If the default is not set within the provisioner, then the global
// default from the authority configuration will be used. // default from the authority configuration will be used.
func (pc *Claims) DefaultTLSCertDuration() time.Duration { func (c *Claimer) DefaultTLSCertDuration() time.Duration {
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 { if c.claims == nil || c.claims.DefaultTLSDur == nil {
return pc.globalClaims.DefaultTLSCertDuration() return c.global.DefaultTLSDur.Duration
} }
return pc.DefaultTLSDur.Duration return c.claims.DefaultTLSDur.Duration
} }
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner. // MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
// If the minimum is not set within the provisioner, then the global // If the minimum is not set within the provisioner, then the global
// minimum from the authority configuration will be used. // minimum from the authority configuration will be used.
func (pc *Claims) MinTLSCertDuration() time.Duration { func (c *Claimer) MinTLSCertDuration() time.Duration {
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 { if c.claims == nil || c.claims.MinTLSDur == nil {
return pc.globalClaims.MinTLSCertDuration() return c.global.MinTLSDur.Duration
} }
return pc.MinTLSDur.Duration return c.claims.MinTLSDur.Duration
} }
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner. // MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
// If the maximum is not set within the provisioner, then the global // If the maximum is not set within the provisioner, then the global
// maximum from the authority configuration will be used. // maximum from the authority configuration will be used.
func (pc *Claims) MaxTLSCertDuration() time.Duration { func (c *Claimer) MaxTLSCertDuration() time.Duration {
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 { if c.claims == nil || c.claims.MaxTLSDur == nil {
return pc.globalClaims.MaxTLSCertDuration() return c.global.MaxTLSDur.Duration
} }
return pc.MaxTLSDur.Duration return c.claims.MaxTLSDur.Duration
} }
// IsDisableRenewal returns if the renewal flow is disabled for the // IsDisableRenewal returns if the renewal flow is disabled for the
// provisioner. If the property is not set within the provisioner, then the // provisioner. If the property is not set within the provisioner, then the
// global value from the authority configuration will be used. // global value from the authority configuration will be used.
func (pc *Claims) IsDisableRenewal() bool { func (c *Claimer) IsDisableRenewal() bool {
if pc.DisableRenewal == nil { if c.claims == nil || c.claims.DisableRenewal == nil {
return pc.globalClaims.IsDisableRenewal() return *c.global.DisableRenewal
} }
return *pc.DisableRenewal return *c.claims.DisableRenewal
} }
// Validate validates and modifies the Claims with default values. // Validate validates and modifies the Claims with default values.
func (pc *Claims) Validate() error { func (c *Claimer) Validate() error {
var ( var (
min = pc.MinTLSCertDuration() min = c.MinTLSCertDuration()
max = pc.MaxTLSCertDuration() max = c.MaxTLSCertDuration()
def = pc.DefaultTLSCertDuration() def = c.DefaultTLSCertDuration()
) )
switch { switch {
case min == 0: case min <= 0:
return errors.Errorf("claims: MinTLSCertDuration cannot be empty") return errors.Errorf("claims: MinTLSCertDuration must be greater than 0")
case max == 0: case max <= 0:
return errors.Errorf("claims: MaxTLSCertDuration cannot be empty") return errors.Errorf("claims: MaxTLSCertDuration must be greater than 0")
case def == 0: case def <= 0:
return errors.Errorf("claims: DefaultTLSCertDuration cannot be empty") return errors.Errorf("claims: DefaultTLSCertDuration must be greater than 0")
case max < min: case max < min:
return errors.Errorf("claims: MaxCertDuration cannot be less "+ return errors.Errorf("claims: MaxCertDuration cannot be less "+
"than MinCertDuration: MaxCertDuration - %v, MinCertDuration - %v", max, min) "than MinCertDuration: MaxCertDuration - %v, MinCertDuration - %v", max, min)

View file

@ -23,6 +23,7 @@ type JWK struct {
Key *jose.JSONWebKey `json:"key"` Key *jose.JSONWebKey `json:"key"`
EncryptedKey string `json:"encryptedKey,omitempty"` EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
audiences []string audiences []string
} }
@ -57,7 +58,12 @@ func (p *JWK) Init(config Config) (err error) {
case p.Key == nil: case p.Key == nil:
return errors.New("provisioner key cannot be empty") return errors.New("provisioner key cannot be empty")
} }
p.Claims, err = p.Claims.Init(&config.Claims)
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences p.audiences = config.Audiences
return err return err
} }
@ -104,15 +110,15 @@ func (p *JWK) Authorize(token string) ([]SignOption, error) {
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
dnsNamesValidator(dnsNames), dnsNamesValidator(dnsNames),
ipAddressesValidator(ips), ipAddressesValidator(ips),
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()), profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()), newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
// AuthorizeRenewal returns an error if the renewal is disabled. // AuthorizeRenewal returns an error if the renewal is disabled.
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error { func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
if p.Claims.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
} }
return nil return nil

View file

@ -201,10 +201,9 @@ func TestJWK_AuthorizeRenewal(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{ p2.Claims = &Claims{DisableRenewal: &disable}
globalClaims: &globalProvisionerClaims, p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
DisableRenewal: &disable, assert.FatalError(t, err)
}
type args struct { type args struct {
cert *x509.Certificate cert *x509.Certificate

View file

@ -55,6 +55,7 @@ type OIDC struct {
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
claimer *Claimer
} }
// IsAdmin returns true if the given email is in the Admins whitelist, false // IsAdmin returns true if the given email is in the Admins whitelist, false
@ -111,9 +112,10 @@ func (o *OIDC) Init(config Config) (err error) {
} }
// Update claims with global ones // Update claims with global ones
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil { if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err return err
} }
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil { if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
return err return err
@ -202,23 +204,23 @@ func (o *OIDC) Authorize(token string) ([]SignOption, error) {
// Admins should be able to authorize any SAN // Admins should be able to authorize any SAN
if o.IsAdmin(claims.Email) { if o.IsAdmin(claims.Email) {
return []SignOption{ return []SignOption{
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()), profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()), newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
return []SignOption{ return []SignOption{
emailOnlyIdentity(claims.Email), emailOnlyIdentity(claims.Email),
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()), profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()), newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
// AuthorizeRenewal returns an error if the renewal is disabled. // AuthorizeRenewal returns an error if the renewal is disabled.
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
if o.Claims.IsDisableRenewal() { if o.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", o.GetID()) return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
} }
return nil return nil

View file

@ -241,10 +241,9 @@ func TestOIDC_AuthorizeRenewal(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{ p2.Claims = &Claims{DisableRenewal: &disable}
globalClaims: &globalProvisionerClaims, p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
DisableRenewal: &disable, assert.FatalError(t, err)
}
type args struct { type args struct {
cert *x509.Certificate cert *x509.Certificate

View file

@ -109,6 +109,10 @@ func generateJWK() (*JWK, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &JWK{ return &JWK{
Name: name, Name: name,
Type: "JWK", Type: "JWK",
@ -116,6 +120,7 @@ func generateJWK() (*JWK, error) {
EncryptedKey: encrypted, EncryptedKey: encrypted,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, audiences: testAudiences,
claimer: claimer,
}, nil }, nil
} }
@ -136,6 +141,10 @@ func generateOIDC() (*OIDC, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &OIDC{ return &OIDC{
Name: name, Name: name,
Type: "OIDC", Type: "OIDC",
@ -150,6 +159,7 @@ func generateOIDC() (*OIDC, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
claimer: claimer,
}, nil }, nil
} }