diff --git a/.golangci.yml b/.golangci.yml index 67aac2df..59c58490 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -73,3 +73,4 @@ issues: - error strings should not be capitalized or end with punctuation or a newline - Wrapf call needs 1 arg but has 2 args - cs.NegotiatedProtocolIsMutual is deprecated + - rewrite if-else to switch statement diff --git a/acme/api/middleware.go b/acme/api/middleware.go index d701f240..b826d1fa 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -283,7 +283,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } // lookupProvisioner loads the provisioner associated with the request. -// Responsds 404 if the provisioner does not exist. +// Responds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/acme/api/order.go b/acme/api/order.go old mode 100644 new mode 100755 index 9cf2c1eb..3d22ec0f --- a/acme/api/order.go +++ b/acme/api/order.go @@ -35,6 +35,8 @@ func (n *NewOrderRequest) Validate() error { if id.Type == acme.IP && net.ParseIP(id.Value) == nil { return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) } + // TODO: add some validations for DNS domains? + // TODO: combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 } return nil } @@ -83,6 +85,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } + var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, @@ -95,6 +98,22 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } + // TODO(hs): this should also verify rules set in the Account (i.e. allowed/denied + // DNS and IPs; it's probably good to connect those to the EAB credentials and management? Or + // should we do it fully properly and connect them to the Account directly? The latter would allow + // management of allowed/denied names based on just the name, without having bound to EAB. Still, + // EAB is not illogical, because that's the way Accounts are connected to an external system and + // thus make sense to also set the allowed/denied names based on that info. + + for _, identifier := range nor.Identifiers { + // TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example? + err = prov.AuthorizeOrderIdentifier(ctx, identifier.Value) + if err != nil { + api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + } + now := clock.Now() // New order. o := &acme.Order{ diff --git a/acme/common.go b/acme/common.go index 0c9e83dc..4b086dd7 100644 --- a/acme/common.go +++ b/acme/common.go @@ -30,6 +30,7 @@ var clock Clock // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { + AuthorizeOrderIdentifier(ctx context.Context, identifier string) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error GetID() string @@ -40,14 +41,15 @@ type Provisioner interface { // MockProvisioner for testing type MockProvisioner struct { - Mret1 interface{} - Merr error - MgetID func() string - MgetName func() string - MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) - MauthorizeRevoke func(ctx context.Context, token string) error - MdefaultTLSCertDuration func() time.Duration - MgetOptions func() *provisioner.Options + Mret1 interface{} + Merr error + MgetID func() string + MgetName func() string + MauthorizeOrderIdentifier func(ctx context.Context, identifier string) error + MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + MauthorizeRevoke func(ctx context.Context, token string) error + MdefaultTLSCertDuration func() time.Duration + MgetOptions func() *provisioner.Options } // GetName mock @@ -58,6 +60,14 @@ func (m *MockProvisioner) GetName() string { return m.Mret1.(string) } +// AuthorizeOrderIdentifiers mock +func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { + if m.MauthorizeOrderIdentifier != nil { + return m.MauthorizeOrderIdentifier(ctx, identifier) + } + return m.Merr +} + // AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.MauthorizeSign != nil { diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6d524a25..08090e22 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -483,7 +483,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) // number of provisioner.SignOptions returned } } }) @@ -995,7 +995,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) // number of provisioner.SignOptions returned } } }) diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go old mode 100644 new mode 100755 index c8950568..c6cadf51 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -3,6 +3,7 @@ package provisioner import ( "context" "crypto/x509" + "net" "time" "github.com/pkg/errors" @@ -67,8 +68,9 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration { return p.claimer.DefaultTLSCertDuration() } -// Init initializes and validates the fields of a JWK type. +// Init initializes and validates the fields of an ACME type. func (p *ACME) Init(config Config) (err error) { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -81,6 +83,47 @@ func (p *ACME) Init(config Config) (err error) { return err } + // Initialize the x509 allow/deny policy engine + // TODO(hs): ensure no race conditions happen when reloading settings and requesting certs? + // TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny? + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + return nil +} + +// ACMEIdentifierType encodes ACME Identifier types +type ACMEIdentifierType string + +const ( + // IP is the ACME ip identifier type + IP ACMEIdentifierType = "ip" + // DNS is the ACME dns identifier type + DNS ACMEIdentifierType = "dns" +) + +// ACMEIdentifier encodes ACME Order Identifiers +type ACMEIdentifier struct { + Type ACMEIdentifierType + Value string +} + +// AuthorizeOrderIdentifiers verifies the provisioner is authorized to issue a +// certificate for the Identifiers provided in an Order. +func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { + + if p.x509PolicyEngine == nil { + return nil + } + + var err error + if ip := net.ParseIP(identifier); ip != nil { + _, err = p.x509PolicyEngine.IsIPAllowed(ip) + } else { + _, err = p.x509PolicyEngine.IsDNSAllowed(identifier) + } + return err } @@ -88,7 +131,7 @@ func (p *ACME) Init(config Config) (err error) { // in the ACME protocol. This method returns a list of modifiers / constraints // on the resulting certificate. func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{ + opts := []SignOption{ // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), @@ -96,7 +139,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), - }, nil + newX509NamePolicyValidator(p.x509PolicyEngine), + } + + return opts, nil } // AuthorizeRevoke is called just before the certificate is to be revoked by diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bd173f87..b9f52253 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -168,7 +168,7 @@ func TestACME_AuthorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { - assert.Len(t, 5, opts) + assert.Len(t, 6, opts) // number of SignOptions returned for _, o := range opts { switch v := o.(type) { case *provisionerExtensionOption: @@ -184,6 +184,8 @@ func TestACME_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index fdad7b4a..9f542873 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -392,6 +392,7 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) { // Init validates and initializes the AWS provisioner. func (p *AWS) Init(config Config) (err error) { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -425,6 +426,16 @@ func (p *AWS) Init(config Config) (err error) { } } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + return nil } @@ -478,6 +489,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), ), nil } @@ -759,5 +771,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 0d2786db..beef8642 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -641,11 +641,11 @@ func TestAWS_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false}, - {"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false}, - {"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false}, + {"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false}, + {"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false}, + {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false}, + {"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, @@ -697,6 +697,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 55d77f49..b8bbe143 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -191,6 +191,7 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) { // Init validates and initializes the Azure provisioner. func (p *Azure) Init(config Config) (err error) { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -221,6 +222,16 @@ func (p *Azure) Init(config Config) (err error) { return err } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + return nil } @@ -328,6 +339,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), ), nil } @@ -396,6 +408,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 7f8d6017..7e184a27 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -431,9 +431,9 @@ func TestAzure_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p1, args{t11}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 11, http.StatusOK, false}, + {"ok", p1, args{t11}, 6, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, @@ -480,6 +480,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e46f4ce4..4c7f2046 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -195,6 +195,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { // Init validates and initializes the GCP provisioner. func (p *GCP) Init(config Config) error { + p.base = &base{} // prevent nil pointers var err error switch { case p.Type == "": @@ -216,6 +217,16 @@ func (p *GCP) Init(config Config) error { return err } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) return nil } @@ -273,6 +284,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // validators defaultPublicKeyValidator{}, newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), ), nil } @@ -438,5 +450,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 5f6f9bc7..8c54c4c5 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -515,9 +515,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p3, args{t3}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 11, http.StatusOK, false}, + {"ok", p3, args{t3}, 6, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, @@ -569,6 +569,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go old mode 100644 new mode 100755 index 137915c8..081eb60c --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -89,6 +89,7 @@ func (p *JWK) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a JWK type. func (p *JWK) Init(config Config) (err error) { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -103,6 +104,16 @@ func (p *JWK) Init(config Config) (err error) { return err } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + p.audiences = config.Audiences return err } @@ -185,6 +196,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), }, nil } @@ -268,6 +280,8 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index deae8f7a..cb43627b 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -295,7 +295,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) for _, o := range got { switch v := o.(type) { case certificateOptionsFunc: @@ -314,6 +314,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index d260f5ec..707e141e 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -92,6 +92,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a K8sSA type. func (p *K8sSA) Init(config Config) (err error) { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -143,6 +144,16 @@ func (p *K8sSA) Init(config Config) (err error) { return err } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + p.audiences = config.Audiences return err } @@ -244,6 +255,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), }, nil } @@ -289,6 +301,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 176cdfd3..3ccce461 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -271,7 +271,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 for _, o := range opts { switch v := o.(type) { case certificateOptionsFunc: @@ -286,12 +285,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 5) + assert.Len(t, 6, opts) } } } @@ -358,7 +358,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 + assert.Len(t, 7, opts) for _, o := range opts { switch v := o.(type) { case sshCertificateOptionsFunc: @@ -370,12 +370,12 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertDefaultValidator: case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 6) } } } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index ac1f2a25..707f8228 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -154,6 +154,7 @@ func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) { // Init validates and initializes the OIDC provider. func (o *OIDC) Init(config Config) (err error) { + o.base = &base{} // prevent nil pointers switch { case o.Type == "": return errors.New("type cannot be empty") @@ -207,6 +208,17 @@ func (o *OIDC) Init(config Config) (err error) { } else { o.getIdentityFunc = config.GetIdentityFunc } + + // Initialize the x509 allow/deny policy engine + if o.x509PolicyEngine, err = newX509PolicyEngine(o.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if o.sshPolicyEngine, err = newSSHPolicyEngine(o.Options.GetSSHOptions()); err != nil { + return err + } + return nil } @@ -363,6 +375,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(o.x509PolicyEngine), }, nil } @@ -452,6 +465,8 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption &sshCertValidityValidator{o.claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(o.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 7bf6ad7a..92d4ca95 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -322,7 +322,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { - assert.Len(t, 5, got) + assert.Len(t, 6, got) for _, o := range got { switch v := o.(type) { case certificateOptionsFunc: @@ -339,6 +339,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go old mode 100644 new mode 100755 index f86c4863..7c516f6d --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -56,6 +56,12 @@ type X509Options struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // AllowedNames contains the SANs the provisioner is authorized to sign + AllowedNames *AllowedX509NameOptions `json:"allow,omitempty"` + + // DeniedNames contains the SANs the provisioner is not authorized to sign + DeniedNames *DeniedX509NameOptions `json:"deny,omitempty"` } // HasTemplate returns true if a template is defined in the provisioner options. @@ -63,6 +69,58 @@ func (o *X509Options) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } +// GetAllowedNameOptions returns the AllowedNameOptions, which models the +// SANs that a provisioner is authorized to sign x509 certificates for. +func (o *X509Options) GetAllowedNameOptions() *AllowedX509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedNameOptions, which models the +// SANs that a provisioner is NOT authorized to sign x509 certificates for. +func (o *X509Options) GetDeniedNameOptions() *DeniedX509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// AllowedX509NameOptions models the allowed names +type AllowedX509NameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` // TODO(hs): support IPs as well as ranges + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +// DeniedX509NameOptions models the denied names +type DeniedX509NameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` // TODO(hs): support IPs as well as ranges + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +// HasNames checks if the AllowedNameOptions has one or more +// names configured. +func (o *AllowedX509NameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.URIDomains) > 0 +} + +// HasNames checks if the DeniedNameOptions has one or more +// names configured. +func (o *DeniedX509NameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.URIDomains) > 0 +} + // TemplateOptions generates a CertificateOptions with the template and data // defined in the ProvisionerOptions, the provisioner generated data, and the // user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/policy.go b/authority/provisioner/policy.go new file mode 100644 index 00000000..cf436d70 --- /dev/null +++ b/authority/provisioner/policy.go @@ -0,0 +1,68 @@ +package provisioner + +import ( + sshpolicy "github.com/smallstep/certificates/policy/ssh" + x509policy "github.com/smallstep/certificates/policy/x509" +) + +// newX509PolicyEngine creates a new x509 name policy engine +func newX509PolicyEngine(x509Opts *X509Options) (*x509policy.NamePolicyEngine, error) { + + if x509Opts == nil { + return nil, nil + } + + options := []x509policy.NamePolicyOption{} + + allowed := x509Opts.GetAllowedNameOptions() + if allowed != nil && allowed.HasNames() { + options = append(options, + x509policy.WithPermittedDNSDomains(allowed.DNSDomains), // TODO(hs): be a bit more lenient w.r.t. the format of domains? I.e. allow "*.localhost" instead of the ".localhost", which is what Name Constraints do. + x509policy.WithPermittedCIDRs(allowed.IPRanges), // TODO(hs): support IPs in addition to ranges + x509policy.WithPermittedEmailAddresses(allowed.EmailAddresses), + x509policy.WithPermittedURIDomains(allowed.URIDomains), + ) + } + + denied := x509Opts.GetDeniedNameOptions() + if denied != nil && denied.HasNames() { + options = append(options, + x509policy.WithExcludedDNSDomains(denied.DNSDomains), // TODO(hs): be a bit more lenient w.r.t. the format of domains? I.e. allow "*.localhost" instead of the ".localhost", which is what Name Constraints do. + x509policy.WithExcludedCIDRs(denied.IPRanges), // TODO(hs): support IPs in addition to ranges + x509policy.WithExcludedEmailAddresses(denied.EmailAddresses), + x509policy.WithExcludedURIDomains(denied.URIDomains), + ) + } + + return x509policy.New(options...) +} + +// newSSHPolicyEngine creates a new SSH name policy engine +func newSSHPolicyEngine(sshOpts *SSHOptions) (*sshpolicy.NamePolicyEngine, error) { + + if sshOpts == nil { + return nil, nil + } + + options := []sshpolicy.NamePolicyOption{} + + allowed := sshOpts.GetAllowedNameOptions() + if allowed != nil && allowed.HasNames() { + options = append(options, + sshpolicy.WithPermittedDNSDomains(allowed.DNSDomains), // TODO(hs): be a bit more lenient w.r.t. the format of domains? I.e. allow "*.localhost" instead of the ".localhost", which is what Name Constraints do. + sshpolicy.WithPermittedEmailAddresses(allowed.EmailAddresses), + sshpolicy.WithPermittedPrincipals(allowed.Principals), + ) + } + + denied := sshOpts.GetDeniedNameOptions() + if denied != nil && denied.HasNames() { + options = append(options, + sshpolicy.WithExcludedDNSDomains(denied.DNSDomains), // TODO(hs): be a bit more lenient w.r.t. the format of domains? I.e. allow "*.localhost" instead of the ".localhost", which is what Name Constraints do. + sshpolicy.WithExcludedEmailAddresses(denied.EmailAddresses), + sshpolicy.WithExcludedPrincipals(denied.Principals), + ) + } + + return sshpolicy.New(options...) +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 5d6b2f80..34ea8c4d 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -12,6 +12,8 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" + sshpolicy "github.com/smallstep/certificates/policy/ssh" + x509policy "github.com/smallstep/certificates/policy/x509" "golang.org/x/crypto/ssh" ) @@ -298,7 +300,10 @@ func SanitizeSSHUserPrincipal(email string) string { }, strings.ToLower(email)) } -type base struct{} +type base struct { + x509PolicyEngine *x509policy.NamePolicyEngine + sshPolicyEngine *sshpolicy.NamePolicyEngine +} // AuthorizeSign returns an unimplemented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing x509 Certificates. diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 145a1920..7c78d14b 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -74,7 +74,7 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { // Init initializes and validates the fields of a SCEP type. func (s *SCEP) Init(config Config) (err error) { - + s.base = &base{} // prevent nil pointers switch { case s.Type == "": return errors.New("provisioner type cannot be empty") @@ -102,6 +102,11 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? + // Initialize the x509 allow/deny policy engine + if s.x509PolicyEngine, err = newX509PolicyEngine(s.Options.GetX509Options()); err != nil { + return err + } + return err } @@ -117,6 +122,7 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(s.x509PolicyEngine), }, nil } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go old mode 100644 new mode 100755 index 34b2e99b..ccc55435 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -16,6 +16,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/errs" + x509policy "github.com/smallstep/certificates/policy/x509" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" ) @@ -404,6 +405,32 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } +// x509NamePolicyValidator validates that the certificate (to be signed) +// contains only allowed SANs. +type x509NamePolicyValidator struct { + policyEngine *x509policy.NamePolicyEngine +} + +// newX509NamePolicyValidator return a new SANs allow/deny validator. +func newX509NamePolicyValidator(engine *x509policy.NamePolicyEngine) *x509NamePolicyValidator { + return &x509NamePolicyValidator{ + policyEngine: engine, + } +} + +// Valid validates validates that the certificate (to be signed) +// contains only allowed SANs. +func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { + if v.policyEngine == nil { + return nil + } + _, err := v.policyEngine.AreCertificateNamesAllowed(cert) + if err != nil { + return err + } + return nil +} + var ( stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index a2ca78b1..e5bd2121 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/errs" + sshpolicy "github.com/smallstep/certificates/policy/ssh" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" ) @@ -444,6 +445,35 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti } } +// sshNamePolicyValidator validates that the certificate (to be signed) +// contains only allowed principals. +type sshNamePolicyValidator struct { + policyEngine *sshpolicy.NamePolicyEngine +} + +// newSSHNamePolicyValidator return a new SSH allow/deny validator. +func newSSHNamePolicyValidator(engine *sshpolicy.NamePolicyEngine) *sshNamePolicyValidator { + return &sshNamePolicyValidator{ + policyEngine: engine, + } +} + +// Valid validates validates that the certificate (to be signed) +// contains only allowed principals. +func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { + if v.policyEngine == nil { + return nil + } + // TODO(hs): should this perform checks only for hosts vs. user certs depending on context? + // The current best practice is to have separate provisioners for hosts and users, and thus + // separate policy engines for the principals that are allowed. + _, err := v.policyEngine.ArePrincipalsAllowed(cert) + if err != nil { + return err + } + return nil +} + // sshCertTypeUInt32 func sshCertTypeUInt32(ct string) uint32 { switch ct { diff --git a/authority/provisioner/ssh_options.go b/authority/provisioner/ssh_options.go index 7ee236d1..ada26d7d 100644 --- a/authority/provisioner/ssh_options.go +++ b/authority/provisioner/ssh_options.go @@ -33,6 +33,26 @@ type SSHOptions struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // AllowedNames contains the names the provisioner is authorized to sign + AllowedNames *AllowedSSHNameOptions `json:"allow,omitempty"` + + // DeniedNames contains the names the provisioner is not authorized to sign + DeniedNames *DeniedSSHNameOptions `json:"deny,omitempty"` +} + +// AllowedSSHNameOptions models the allowed names +type AllowedSSHNameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +// DeniedSSHNameOptions models the denied names +type DeniedSSHNameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` } // HasTemplate returns true if a template is defined in the provisioner options. @@ -40,6 +60,40 @@ func (o *SSHOptions) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } +// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the +// names that a provisioner is authorized to sign SSH certificates for. +func (o *SSHOptions) GetAllowedNameOptions() *AllowedSSHNameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the +// names that a provisioner is NOT authorized to sign SSH certificates for. +func (o *SSHOptions) GetDeniedNameOptions() *DeniedSSHNameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// HasNames checks if the AllowedSSHNameOptions has one or more +// names configured. +func (o *AllowedSSHNameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.EmailAddresses) > 0 || + len(o.Principals) > 0 +} + +// HasNames checks if the DeniedSSHNameOptions has one or more +// names configured. +func (o *DeniedSSHNameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.EmailAddresses) > 0 || + len(o.Principals) > 0 +} + // TemplateSSHOptions generates a SSHCertificateOptions with the template and // data defined in the ProvisionerOptions, the provisioner generated data, and // the user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3039d2a3..b41f512e 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -84,6 +84,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a SSHPOP type. func (p *SSHPOP) Init(config Config) error { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -99,6 +100,8 @@ func (p *SSHPOP) Init(config Config) error { return err } + // TODO(hs): initialize the policy engine and add it as an SSH cert validator + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys return nil diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..ea0890ae 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -177,6 +177,7 @@ func generateJWK() (*JWK, error) { return nil, err } return &JWK{ + base: &base{}, Name: name, Type: "JWK", Key: &public, @@ -215,6 +216,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { } return &K8sSA{ + base: &base{}, Name: K8sSAName, Type: "K8sSA", Claims: &globalProvisionerClaims, @@ -252,6 +254,7 @@ func generateSSHPOP() (*SSHPOP, error) { } return &SSHPOP{ + base: &base{}, Name: name, Type: "SSHPOP", Claims: &globalProvisionerClaims, @@ -306,6 +309,7 @@ M46l92gdOozT rootPool.AddCert(cert) } return &X5C{ + base: &base{}, Name: name, Type: "X5C", Roots: root, @@ -338,6 +342,7 @@ func generateOIDC() (*OIDC, error) { return nil, err } return &OIDC{ + base: &base{}, Name: name, Type: "OIDC", ClientID: clientID, @@ -373,6 +378,7 @@ func generateGCP() (*GCP, error) { return nil, err } return &GCP{ + base: &base{}, Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, @@ -409,6 +415,7 @@ func generateAWS() (*AWS, error) { return nil, errors.Wrap(err, "error parsing AWS certificate") } return &AWS{ + base: &base{}, Type: "AWS", Name: name, Accounts: []string{accountID}, @@ -518,6 +525,7 @@ func generateAWSV1Only() (*AWS, error) { return nil, errors.Wrap(err, "error parsing AWS certificate") } return &AWS{ + base: &base{}, Type: "AWS", Name: name, Accounts: []string{accountID}, @@ -609,6 +617,7 @@ func generateAzure() (*Azure, error) { return nil, err } return &Azure{ + base: &base{}, Type: "Azure", Name: name, TenantID: tenantID, diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 8710acb5..a87e4392 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -87,6 +87,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a X5C type. func (p *X5C) Init(config Config) error { + p.base = &base{} // prevent nil pointers switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -125,6 +126,16 @@ func (p *X5C) Init(config Config) error { return err } + // Initialize the x509 allow/deny policy engine + if p.x509PolicyEngine, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + return err + } + + // Initialize the SSH allow/deny policy engine + if p.sshPolicyEngine, err = newSSHPolicyEngine(p.Options.GetSSHOptions()); err != nil { + return err + } + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) return nil } @@ -229,6 +240,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.x509PolicyEngine), }, nil } @@ -311,5 +323,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.sshPolicyEngine), ), nil } diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 2959f8c6..5d2a3566 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -463,7 +463,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, len(opts), 7) + assert.Len(t, 8, opts) for _, o := range opts { switch v := o.(type) { case certificateOptionsFunc: @@ -474,7 +474,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) @@ -486,6 +485,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -778,6 +779,8 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) @@ -785,9 +788,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tot++ } if len(tc.claims.Step.SSH.CertType) > 0 { - assert.Equals(t, tot, 9) + assert.Equals(t, tot, 10) } else { - assert.Equals(t, tot, 7) + assert.Equals(t, tot, 8) } } } diff --git a/policy/ssh/options.go b/policy/ssh/options.go new file mode 100644 index 00000000..30b68a1d --- /dev/null +++ b/policy/ssh/options.go @@ -0,0 +1,99 @@ +package sshpolicy + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" +) + +type NamePolicyOption func(g *NamePolicyEngine) error + +func WithPermittedDNSDomains(domains []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + } + g.permittedDNSDomains = domains + return nil + } +} + +func WithExcludedDNSDomains(domains []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse excluded domain constraint %q", domain) + } + } + g.excludedDNSDomains = domains + return nil + } +} + +func WithPermittedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + g.permittedEmailAddresses = emailAddresses + return nil + } +} + +func WithExcludedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + g.excludedEmailAddresses = emailAddresses + return nil + } +} + +func WithPermittedPrincipals(principals []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + // for _, principal := range principals { + // // TODO: validation? + // } + g.permittedPrincipals = principals + return nil + } +} + +func WithExcludedPrincipals(principals []string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + // for _, principal := range principals { + // // TODO: validation? + // } + g.excludedPrincipals = principals + return nil + } +} + +func validateDNSDomainConstraint(domain string) error { + if _, ok := domainToReverseLabels(domain); !ok { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + return nil +} + +func validateEmailConstraint(constraint string) error { + if strings.Contains(constraint, "@") { + _, ok := parseRFC2821Mailbox(constraint) + if !ok { + return fmt.Errorf("cannot parse email constraint %q", constraint) + } + } + _, ok := domainToReverseLabels(constraint) + if !ok { + return fmt.Errorf("cannot parse email domain constraint %q", constraint) + } + return nil +} diff --git a/policy/ssh/ssh.go b/policy/ssh/ssh.go new file mode 100644 index 00000000..95e7d471 --- /dev/null +++ b/policy/ssh/ssh.go @@ -0,0 +1,472 @@ +package sshpolicy + +import ( + "bytes" + "crypto/x509" + "fmt" + "reflect" + "strings" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +type CertificateInvalidError struct { + Reason x509.InvalidReason + Detail string +} + +func (e CertificateInvalidError) Error() string { + switch e.Reason { + // TODO: include logical errors for this package; exlude ones that don't make sense for its current use case? + // TODO: currently only CANotAuthorizedForThisName is used by this package; we're not checking the other things in CSRs in this package. + case x509.NotAuthorizedToSign: + return "not authorized to sign other certificates" // TODO: this one doesn't make sense for this pkg + case x509.Expired: + return "csr has expired or is not yet valid: " + e.Detail + case x509.CANotAuthorizedForThisName: + return "not authorized to sign for this name: " + e.Detail + case x509.CANotAuthorizedForExtKeyUsage: + return "not authorized for an extended key usage: " + e.Detail + case x509.TooManyIntermediates: + return "too many intermediates for path length constraint" + case x509.IncompatibleUsage: + return "csr specifies an incompatible key usage" + case x509.NameMismatch: + return "issuer name does not match subject from issuing certificate" + case x509.NameConstraintsWithoutSANs: + return "issuer has name constraints but csr doesn't have a SAN extension" + case x509.UnconstrainedName: + return "issuer has name constraints but csr contains unknown or unconstrained name: " + e.Detail + } + return "unknown error" +} + +type NamePolicyEngine struct { + options []NamePolicyOption + permittedDNSDomains []string + excludedDNSDomains []string + permittedEmailAddresses []string + excludedEmailAddresses []string + permittedPrincipals []string // TODO: rename to usernames, as principals can be host, user@ (like mail) and usernames? + excludedPrincipals []string +} + +func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) { + + e := &NamePolicyEngine{} // TODO: embed an x509 engine instead of building it again? + e.options = append(e.options, opts...) + for _, option := range e.options { + if err := option(e); err != nil { + return nil, err + } + } + + return e, nil +} + +func (e *NamePolicyEngine) ArePrincipalsAllowed(cert *ssh.Certificate) (bool, error) { + dnsNames, emails, userNames := splitPrincipals(cert.ValidPrincipals) + if err := e.validateNames(dnsNames, emails, userNames); err != nil { + return false, err + } + return true, nil +} + +func (e *NamePolicyEngine) validateNames(dnsNames, emails, userNames []string) error { + //"dns": ["*.smallstep.com"], + //"email": ["@smallstep.com", "@google.com"], + //"principal": ["max", "mariano", "mike"] + /* No regexes for now. But if we ever implement them, they'd probably look like this */ + /*"principal": ["foo.smallstep.com", "/^*\.smallstep\.com$/"]*/ + + // Principals can be single user names (mariano, max, mike, ...), hostnames/domains (*.smallstep.com, host.smallstep.com, ...) and emails (max@smallstep.com, @smallstep.com, ...) + // All ValidPrincipals can thus be any one of those, and they can be mixed (mike@smallstep.com, mike, ...); we need to split this? + // Should we assume a generic engine, or can we do it host vs. user based? If host vs. user based, then it becomes easier w.r.t. dns; hosts will only be DNS, right? + // If we assume generic, we _may_ have a harder time distinguishing host vs. user certs. We propose to use host + user specific provisioners, though... + // Perhaps we can do some heuristics on the principal names vs. hostnames (i.e. when only a single label and no dot, then it's a user principal) + + for _, dns := range dnsNames { + if _, ok := domainToReverseLabels(dns); !ok { + return errors.Errorf("cannot parse dns %q", dns) + } + if err := checkNameConstraints("dns", dns, dns, + func(parsedName, constraint interface{}) (bool, error) { + return matchDomainConstraint(parsedName.(string), constraint.(string)) + }, e.permittedDNSDomains, e.excludedDNSDomains); err != nil { + return err + } + } + + for _, email := range emails { + mailbox, ok := parseRFC2821Mailbox(email) + if !ok { + return fmt.Errorf("cannot parse rfc822Name %q", mailbox) + } + if err := checkNameConstraints("email", email, mailbox, + func(parsedName, constraint interface{}) (bool, error) { + return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string)) + }, e.permittedEmailAddresses, e.excludedEmailAddresses); err != nil { + return err + } + } + + for _, userName := range userNames { + // TODO: some validation? I.e. allowed characters? + if err := checkNameConstraints("username", userName, userName, + func(parsedName, constraint interface{}) (bool, error) { + return matchUserNameConstraint(parsedName.(string), constraint.(string)) + }, e.permittedPrincipals, e.excludedPrincipals); err != nil { + return err + } + } + + return nil +} + +// splitPrincipals splits SSH certificate principals into DNS names, emails and user names. +func splitPrincipals(principals []string) (dnsNames, emails, userNames []string) { + dnsNames = []string{} + emails = []string{} + userNames = []string{} + for _, principal := range principals { + if strings.Contains(principal, "@") { + emails = append(emails, principal) + } else if len(strings.Split(principal, ".")) > 1 { + dnsNames = append(dnsNames, principal) + } else { + userNames = append(userNames, principal) + } + } + return +} + +// checkNameConstraints checks that c permits a child certificate to claim the +// given name, of type nameType. The argument parsedName contains the parsed +// form of name, suitable for passing to the match function. The total number +// of comparisons is tracked in the given count and should not exceed the given +// limit. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func checkNameConstraints( + nameType string, + name string, + parsedName interface{}, + match func(parsedName, constraint interface{}) (match bool, err error), + permitted, excluded interface{}) error { + + excludedValue := reflect.ValueOf(excluded) + + // *count += excludedValue.Len() + // if *count > maxConstraintComparisons { + // return x509.CertificateInvalidError{c, x509.TooManyConstraints, ""} + // } + + // TODO: fix the errors; return our own, because we don't have cert ... + + for i := 0; i < excludedValue.Len(); i++ { + constraint := excludedValue.Index(i).Interface() + match, err := match(parsedName, constraint) + if err != nil { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: err.Error(), + } + } + + if match { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint), + } + } + } + + permittedValue := reflect.ValueOf(permitted) + + // *count += permittedValue.Len() + // if *count > maxConstraintComparisons { + // return x509.CertificateInvalidError{c, x509.TooManyConstraints, ""} + // } + + ok := true + for i := 0; i < permittedValue.Len(); i++ { + constraint := permittedValue.Index(i).Interface() + var err error + if ok, err = match(parsedName, constraint); err != nil { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: err.Error(), + } + } + + if ok { + break + } + } + + if !ok { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name), + } + } + + return nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchDomainConstraint(domain, constraint string) (bool, error) { + // The meaning of zero length constraints is not specified, but this + // code follows NSS and accepts them as matching everything. + if constraint == "" { + return true, nil + } + + domainLabels, ok := domainToReverseLabels(domain) + if !ok { + return false, fmt.Errorf("cannot parse domain %q", domain) + } + + // RFC 5280 says that a leading period in a domain name means that at + // least one label must be prepended, but only for URI and email + // constraints, not DNS constraints. The code also supports that + // behavior for DNS constraints. + + mustHaveSubdomains := false + if constraint[0] == '.' { + mustHaveSubdomains = true + constraint = constraint[1:] + } + + constraintLabels, ok := domainToReverseLabels(constraint) + if !ok { + return false, fmt.Errorf("cannot parse domain %q", constraint) + } + + if len(domainLabels) < len(constraintLabels) || + (mustHaveSubdomains && len(domainLabels) == len(constraintLabels)) { + return false, nil + } + + for i, constraintLabel := range constraintLabels { + if !strings.EqualFold(constraintLabel, domainLabels[i]) { + return false, nil + } + } + + return true, nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) { + // If the constraint contains an @, then it specifies an exact mailbox name. + if strings.Contains(constraint, "@") { + constraintMailbox, ok := parseRFC2821Mailbox(constraint) + if !ok { + return false, fmt.Errorf("cannot parse constraint %q", constraint) + } + return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil + } + + // Otherwise the constraint is like a DNS constraint of the domain part + // of the mailbox. + return matchDomainConstraint(mailbox.domain, constraint) +} + +// matchUserNameConstraint performs a string literal match against a constraint +func matchUserNameConstraint(userName, constraint string) (bool, error) { + return userName == constraint, nil +} + +// TODO: decrease code duplication: single policy engine again, with principals added, but not used in x509? +// Not sure how I'd like to model that in Go, though: use (embedded) structs? interfaces? An x509 name policy engine +// interface could expose the methods that are useful to x509; the SSH name policy engine interfaces could do the +// same for SSH ones. One interface for both (with no methods?); then two, so that not all name policy options +// can be executed on both types? The shared ones could then maybe use the one with no methods? But we need protect +// it from being applied to just any type, of course. Not sure if Go allows us to do something like that, though. +// Maybe some kind of dummy function helps there? + +// domainToReverseLabels converts a textual domain name like foo.example.com to +// the list of labels in reverse order, e.g. ["com", "example", "foo"]. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { + for len(domain) > 0 { + if i := strings.LastIndexByte(domain, '.'); i == -1 { + reverseLabels = append(reverseLabels, domain) + domain = "" + } else { + reverseLabels = append(reverseLabels, domain[i+1:]) + domain = domain[:i] + } + } + + if len(reverseLabels) > 0 && reverseLabels[0] == "" { + // An empty label at the end indicates an absolute value. + return nil, false + } + + for _, label := range reverseLabels { + if label == "" { + // Empty labels are otherwise invalid. + return nil, false + } + + for _, c := range label { + if c < 33 || c > 126 { + // Invalid character. + return nil, false + } + } + } + + return reverseLabels, true +} + +// rfc2821Mailbox represents a “mailbox” (which is an email address to most +// people) by breaking it into the “local” (i.e. before the '@') and “domain” +// parts. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +type rfc2821Mailbox struct { + local, domain string +} + +// parseRFC2821Mailbox parses an email address into local and domain parts, +// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280, +// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The +// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { + if in == "" { + return mailbox, false + } + + localPartBytes := make([]byte, 0, len(in)/2) + + if in[0] == '"' { + // Quoted-string = DQUOTE *qcontent DQUOTE + // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 + // qcontent = qtext / quoted-pair + // qtext = non-whitespace-control / + // %d33 / %d35-91 / %d93-126 + // quoted-pair = ("\" text) / obs-qp + // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text + // + // (Names beginning with “obs-” are the obsolete syntax from RFC 2822, + // Section 4. Since it has been 16 years, we no longer accept that.) + in = in[1:] + QuotedString: + for { + if in == "" { + return mailbox, false + } + c := in[0] + in = in[1:] + + switch { + case c == '"': + break QuotedString + + case c == '\\': + // quoted-pair + if in == "" { + return mailbox, false + } + if in[0] == 11 || + in[0] == 12 || + (1 <= in[0] && in[0] <= 9) || + (14 <= in[0] && in[0] <= 127) { + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + } else { + return mailbox, false + } + + case c == 11 || + c == 12 || + // Space (char 32) is not allowed based on the + // BNF, but RFC 3696 gives an example that + // assumes that it is. Several “verified” + // errata continue to argue about this point. + // We choose to accept it. + c == 32 || + c == 33 || + c == 127 || + (1 <= c && c <= 8) || + (14 <= c && c <= 31) || + (35 <= c && c <= 91) || + (93 <= c && c <= 126): + // qtext + localPartBytes = append(localPartBytes, c) + + default: + return mailbox, false + } + } + } else { + // Atom ("." Atom)* + NextChar: + for len(in) > 0 { + // atext from RFC 2822, Section 3.2.4 + c := in[0] + + switch { + case c == '\\': + // Examples given in RFC 3696 suggest that + // escaped characters can appear outside of a + // quoted string. Several “verified” errata + // continue to argue the point. We choose to + // accept it. + in = in[1:] + if in == "" { + return mailbox, false + } + fallthrough + + case ('0' <= c && c <= '9') || + ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || + c == '-' || c == '/' || c == '=' || c == '?' || + c == '^' || c == '_' || c == '`' || c == '{' || + c == '|' || c == '}' || c == '~' || c == '.': + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + + default: + break NextChar + } + } + + if len(localPartBytes) == 0 { + return mailbox, false + } + + // From RFC 3696, Section 3: + // “period (".") may also appear, but may not be used to start + // or end the local part, nor may two or more consecutive + // periods appear.” + twoDots := []byte{'.', '.'} + if localPartBytes[0] == '.' || + localPartBytes[len(localPartBytes)-1] == '.' || + bytes.Contains(localPartBytes, twoDots) { + return mailbox, false + } + } + + if in == "" || in[0] != '@' { + return mailbox, false + } + in = in[1:] + + // The RFC species a format for domains, but that's known to be + // violated in practice so we accept that anything after an '@' is the + // domain part. + if _, ok := domainToReverseLabels(in); !ok { + return mailbox, false + } + + mailbox.local = string(localPartBytes) + mailbox.domain = in + return mailbox, true +} diff --git a/policy/x509/options.go b/policy/x509/options.go new file mode 100755 index 00000000..68f236cb --- /dev/null +++ b/policy/x509/options.go @@ -0,0 +1,506 @@ +package x509policy + +import ( + "fmt" + "net" + "strings" + + "github.com/pkg/errors" +) + +type NamePolicyOption func(e *NamePolicyEngine) error + +// TODO: wrap (more) errors; and prove a set of known (exported) errors + +func WithPermittedDNSDomains(domains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + } + e.permittedDNSDomains = domains + return nil + } +} + +func AddPermittedDNSDomains(domains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + } + e.permittedDNSDomains = append(e.permittedDNSDomains, domains...) + return nil + } +} + +func WithExcludedDNSDomains(domains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse excluded domain constraint %q", domain) + } + } + e.excludedDNSDomains = domains + return nil + } +} + +func AddExcludedDNSDomains(domains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range domains { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse excluded domain constraint %q", domain) + } + } + e.excludedDNSDomains = append(e.excludedDNSDomains, domains...) + return nil + } +} + +func WithPermittedDNSDomain(domain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + e.permittedDNSDomains = []string{domain} + return nil + } +} + +func AddPermittedDNSDomain(domain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + e.permittedDNSDomains = append(e.permittedDNSDomains, domain) + return nil + } +} + +func WithExcludedDNSDomain(domain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse excluded domain constraint %q", domain) + } + e.excludedDNSDomains = []string{domain} + return nil + } +} + +func AddExcludedDNSDomain(domain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateDNSDomainConstraint(domain); err != nil { + return errors.Errorf("cannot parse excluded domain constraint %q", domain) + } + e.excludedDNSDomains = append(e.excludedDNSDomains, domain) + return nil + } +} + +func WithPermittedIPRanges(ipRanges []*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.permittedIPRanges = ipRanges + return nil + } +} + +func AddPermittedIPRanges(ipRanges []*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.permittedIPRanges = append(e.permittedIPRanges, ipRanges...) + return nil + } +} + +func WithPermittedCIDRs(cidrs []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := []*net.IPNet{} + for _, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse permitted CIDR constraint %q", cidr) + } + networks = append(networks, nw) + } + e.permittedIPRanges = networks + return nil + } +} + +func AddPermittedCIDRs(cidrs []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := []*net.IPNet{} + for _, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse permitted CIDR constraint %q", cidr) + } + networks = append(networks, nw) + } + e.permittedIPRanges = append(e.permittedIPRanges, networks...) + return nil + } +} + +func WithExcludedCIDRs(cidrs []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := []*net.IPNet{} + for _, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse excluded CIDR constraint %q", cidr) + } + networks = append(networks, nw) + } + e.excludedIPRanges = networks + return nil + } +} + +func AddExcludedCIDRs(cidrs []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := []*net.IPNet{} + for _, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse excluded CIDR constraint %q", cidr) + } + networks = append(networks, nw) + } + e.excludedIPRanges = append(e.excludedIPRanges, networks...) + return nil + } +} + +func WithPermittedCIDR(cidr string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse permitted CIDR constraint %q", cidr) + } + e.permittedIPRanges = []*net.IPNet{nw} + return nil + } +} + +func AddPermittedCIDR(cidr string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse permitted CIDR constraint %q", cidr) + } + e.permittedIPRanges = append(e.permittedIPRanges, nw) + return nil + } +} + +func WithPermittedIP(ip net.IP) NamePolicyOption { + return func(e *NamePolicyEngine) error { + var mask net.IPMask + if !isIPv4(ip) { + mask = net.CIDRMask(128, 128) + } else { + mask = net.CIDRMask(32, 32) + } + nw := &net.IPNet{ + IP: ip, + Mask: mask, + } + e.permittedIPRanges = []*net.IPNet{nw} + return nil + } +} + +func AddPermittedIP(ip net.IP) NamePolicyOption { + return func(e *NamePolicyEngine) error { + var mask net.IPMask + if !isIPv4(ip) { + mask = net.CIDRMask(128, 128) + } else { + mask = net.CIDRMask(32, 32) + } + nw := &net.IPNet{ + IP: ip, + Mask: mask, + } + e.permittedIPRanges = append(e.permittedIPRanges, nw) + return nil + } +} + +func WithExcludedIPRanges(ipRanges []*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.excludedIPRanges = ipRanges + return nil + } +} + +func AddExcludedIPRanges(ipRanges []*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.excludedIPRanges = append(e.excludedIPRanges, ipRanges...) + return nil + } +} + +func WithExcludedCIDR(cidr string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse excluded CIDR constraint %q", cidr) + } + e.excludedIPRanges = []*net.IPNet{nw} + return nil + } +} + +func AddExcludedCIDR(cidr string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return errors.Errorf("cannot parse excluded CIDR constraint %q", cidr) + } + e.excludedIPRanges = append(e.excludedIPRanges, nw) + return nil + } +} + +func WithExcludedIP(ip net.IP) NamePolicyOption { + return func(e *NamePolicyEngine) error { + var mask net.IPMask + if !isIPv4(ip) { + mask = net.CIDRMask(128, 128) + } else { + mask = net.CIDRMask(32, 32) + } + nw := &net.IPNet{ + IP: ip, + Mask: mask, + } + e.excludedIPRanges = []*net.IPNet{nw} + return nil + } +} + +func AddExcludedIP(ip net.IP) NamePolicyOption { + return func(e *NamePolicyEngine) error { + var mask net.IPMask + if !isIPv4(ip) { + mask = net.CIDRMask(128, 128) + } else { + mask = net.CIDRMask(32, 32) + } + nw := &net.IPNet{ + IP: ip, + Mask: mask, + } + e.excludedIPRanges = append(e.excludedIPRanges, nw) + return nil + } +} + +func WithPermittedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + e.permittedEmailAddresses = emailAddresses + return nil + } +} + +func AddPermittedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + e.permittedEmailAddresses = append(e.permittedEmailAddresses, emailAddresses...) + return nil + } +} + +func WithExcludedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + e.excludedEmailAddresses = emailAddresses + return nil + } +} + +func AddExcludedEmailAddresses(emailAddresses []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, email := range emailAddresses { + if err := validateEmailConstraint(email); err != nil { + return err + } + } + e.excludedEmailAddresses = append(e.excludedEmailAddresses, emailAddresses...) + return nil + } +} + +func WithPermittedEmailAddress(emailAddress string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateEmailConstraint(emailAddress); err != nil { + return err + } + e.permittedEmailAddresses = []string{emailAddress} + return nil + } +} + +func AddPermittedEmailAddress(emailAddress string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateEmailConstraint(emailAddress); err != nil { + return err + } + e.permittedEmailAddresses = append(e.permittedEmailAddresses, emailAddress) + return nil + } +} + +func WithExcludedEmailAddress(emailAddress string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateEmailConstraint(emailAddress); err != nil { + return err + } + e.excludedEmailAddresses = []string{emailAddress} + return nil + } +} + +func AddExcludedEmailAddress(emailAddress string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateEmailConstraint(emailAddress); err != nil { + return err + } + e.excludedEmailAddresses = append(e.excludedEmailAddresses, emailAddress) + return nil + } +} + +func WithPermittedURIDomains(uriDomains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range uriDomains { + if err := validateURIDomainConstraint(domain); err != nil { + return err + } + } + e.permittedURIDomains = uriDomains + return nil + } +} + +func AddPermittedURIDomains(uriDomains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range uriDomains { + if err := validateURIDomainConstraint(domain); err != nil { + return err + } + } + e.permittedURIDomains = append(e.permittedURIDomains, uriDomains...) + return nil + } +} + +func WithPermittedURIDomain(uriDomain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateURIDomainConstraint(uriDomain); err != nil { + return err + } + e.permittedURIDomains = []string{uriDomain} + return nil + } +} + +func AddPermittedURIDomain(uriDomain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateURIDomainConstraint(uriDomain); err != nil { + return err + } + e.permittedURIDomains = append(e.permittedURIDomains, uriDomain) + return nil + } +} + +func WithExcludedURIDomains(uriDomains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range uriDomains { + if err := validateURIDomainConstraint(domain); err != nil { + return err + } + } + e.excludedURIDomains = uriDomains + return nil + } +} + +func AddExcludedURIDomains(uriDomains []string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + for _, domain := range uriDomains { + if err := validateURIDomainConstraint(domain); err != nil { + return err + } + } + e.excludedURIDomains = append(e.excludedURIDomains, uriDomains...) + return nil + } +} + +func WithExcludedURIDomain(uriDomain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateURIDomainConstraint(uriDomain); err != nil { + return err + } + e.excludedURIDomains = []string{uriDomain} + return nil + } +} + +func AddExcludedURIDomain(uriDomain string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + if err := validateURIDomainConstraint(uriDomain); err != nil { + return err + } + e.excludedURIDomains = append(e.excludedURIDomains, uriDomain) + return nil + } +} + +func validateDNSDomainConstraint(domain string) error { + if _, ok := domainToReverseLabels(domain); !ok { + return errors.Errorf("cannot parse permitted domain constraint %q", domain) + } + return nil +} + +func validateEmailConstraint(constraint string) error { + if strings.Contains(constraint, "@") { + _, ok := parseRFC2821Mailbox(constraint) + if !ok { + return fmt.Errorf("cannot parse email constraint %q", constraint) + } + } + _, ok := domainToReverseLabels(constraint) + if !ok { + return fmt.Errorf("cannot parse email domain constraint %q", constraint) + } + return nil +} + +func validateURIDomainConstraint(constraint string) error { + _, ok := domainToReverseLabels(constraint) + if !ok { + return fmt.Errorf("cannot parse URI domain constraint %q", constraint) + } + return nil +} diff --git a/policy/x509/x509.go b/policy/x509/x509.go new file mode 100755 index 00000000..c8d4dfb2 --- /dev/null +++ b/policy/x509/x509.go @@ -0,0 +1,565 @@ +package x509policy + +import ( + "bytes" + "crypto/x509" + "fmt" + "net" + "net/url" + "reflect" + "strings" + + "github.com/pkg/errors" + "go.step.sm/crypto/x509util" +) + +type CertificateInvalidError struct { + Reason x509.InvalidReason + Detail string +} + +func (e CertificateInvalidError) Error() string { + switch e.Reason { + // TODO: include logical errors for this package; exlude ones that don't make sense for its current use case? + // TODO: currently only CANotAuthorizedForThisName is used by this package; we're not checking the other things in CSRs in this package. + case x509.NotAuthorizedToSign: + return "not authorized to sign other certificates" // TODO: this one doesn't make sense for this pkg + case x509.Expired: + return "csr has expired or is not yet valid: " + e.Detail + case x509.CANotAuthorizedForThisName: + return "not authorized to sign for this name: " + e.Detail + case x509.CANotAuthorizedForExtKeyUsage: + return "not authorized for an extended key usage: " + e.Detail + case x509.TooManyIntermediates: + return "too many intermediates for path length constraint" + case x509.IncompatibleUsage: + return "csr specifies an incompatible key usage" + case x509.NameMismatch: + return "issuer name does not match subject from issuing certificate" + case x509.NameConstraintsWithoutSANs: + return "issuer has name constraints but csr doesn't have a SAN extension" + case x509.UnconstrainedName: + return "issuer has name constraints but csr contains unknown or unconstrained name: " + e.Detail + } + return "unknown error" +} + +// NamePolicyEngine can be used to check that a CSR or Certificate meets all allowed and +// denied names before a CA creates and/or signs the Certificate. +// TODO(hs): the x509 RFC also defines name checks on directory name; support that? +// TODO(hs): implement Stringer interface: describe the contents of the NamePolicyEngine? +type NamePolicyEngine struct { + options []NamePolicyOption + permittedDNSDomains []string + excludedDNSDomains []string + permittedIPRanges []*net.IPNet + excludedIPRanges []*net.IPNet + permittedEmailAddresses []string + excludedEmailAddresses []string + permittedURIDomains []string + excludedURIDomains []string +} + +// NewNamePolicyEngine creates a new NamePolicyEngine with NamePolicyOptions +func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) { + + e := &NamePolicyEngine{} + e.options = append(e.options, opts...) + for _, option := range e.options { + if err := option(e); err != nil { + return nil, err + } + } + + return e, nil +} + +// AreCertificateNamesAllowed verifies that all SANs in a Certificate are allowed. +func (e *NamePolicyEngine) AreCertificateNamesAllowed(cert *x509.Certificate) (bool, error) { + if err := e.validateNames(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs); err != nil { + return false, err + } + return true, nil +} + +// AreCSRNamesAllowed verifies that all names in the CSR are allowed. +func (e *NamePolicyEngine) AreCSRNamesAllowed(csr *x509.CertificateRequest) (bool, error) { + if err := e.validateNames(csr.DNSNames, csr.IPAddresses, csr.EmailAddresses, csr.URIs); err != nil { + return false, err + } + return true, nil +} + +// AreSANSAllowed verifies that all names in the slice of SANs are allowed. +// The SANs are first split into DNS names, IPs, email addresses and URIs. +func (e *NamePolicyEngine) AreSANsAllowed(sans []string) (bool, error) { + dnsNames, ips, emails, uris := x509util.SplitSANs(sans) + if err := e.validateNames(dnsNames, ips, emails, uris); err != nil { + return false, err + } + return true, nil +} + +// IsDNSAllowed verifies a single DNS domain is allowed. +func (e *NamePolicyEngine) IsDNSAllowed(dns string) (bool, error) { + if err := e.validateNames([]string{dns}, []net.IP{}, []string{}, []*url.URL{}); err != nil { + return false, err + } + return true, nil +} + +// IsIPAllowed verifies a single IP domain is allowed. +func (e *NamePolicyEngine) IsIPAllowed(ip net.IP) (bool, error) { + if err := e.validateNames([]string{}, []net.IP{ip}, []string{}, []*url.URL{}); err != nil { + return false, err + } + return true, nil +} + +// validateNames verifies that all names are allowed. +// Its logic follows that of (a large part of) the (c *Certificate) isValid() function +// in https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailAddresses []string, uris []*url.URL) error { + + // TODO: return our own type of error? + + // TODO: set limit on total of all names? In x509 there's a limit on the number of comparisons + // that protects the CA from a DoS (i.e. many heavy comparisons). The x509 implementation takes + // this number as a total of all checks and keeps a (pointer to a) counter of the number of checks + // executed so far. + + // TODO: gather all errors, or return early? Currently we return early on the first wrong name; check might fail for multiple names. + // Perhaps make that an option? + for _, dns := range dnsNames { + if _, ok := domainToReverseLabels(dns); !ok { + return errors.Errorf("cannot parse dns %q", dns) + } + if err := checkNameConstraints("dns", dns, dns, + func(parsedName, constraint interface{}) (bool, error) { + return matchDomainConstraint(parsedName.(string), constraint.(string)) + }, e.permittedDNSDomains, e.excludedDNSDomains); err != nil { + return err + } + } + + for _, ip := range ips { + if err := checkNameConstraints("ip", ip.String(), ip, + func(parsedName, constraint interface{}) (bool, error) { + return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet)) + }, e.permittedIPRanges, e.excludedIPRanges); err != nil { + return err + } + } + + for _, email := range emailAddresses { + mailbox, ok := parseRFC2821Mailbox(email) + if !ok { + return fmt.Errorf("cannot parse rfc822Name %q", mailbox) + } + if err := checkNameConstraints("email", email, mailbox, + func(parsedName, constraint interface{}) (bool, error) { + return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string)) + }, e.permittedEmailAddresses, e.excludedEmailAddresses); err != nil { + return err + } + } + + for _, uri := range uris { + if err := checkNameConstraints("uri", uri.String(), uri, + func(parsedName, constraint interface{}) (bool, error) { + return matchURIConstraint(parsedName.(*url.URL), constraint.(string)) + }, e.permittedURIDomains, e.excludedURIDomains); err != nil { + return err + } + } + + // TODO: when the error is not nil and returned up in the above, we can add + // additional context to it (i.e. the cert or csr that was inspected). + + // TODO(hs): validate other types of SANs? The Go std library skips those. + // These could be custom checkers. + + // if all checks out, all SANs are allowed + return nil +} + +// checkNameConstraints checks that c permits a child certificate to claim the +// given name, of type nameType. The argument parsedName contains the parsed +// form of name, suitable for passing to the match function. The total number +// of comparisons is tracked in the given count and should not exceed the given +// limit. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func checkNameConstraints( + nameType string, + name string, + parsedName interface{}, + match func(parsedName, constraint interface{}) (match bool, err error), + permitted, excluded interface{}) error { + + excludedValue := reflect.ValueOf(excluded) + + // *count += excludedValue.Len() + // if *count > maxConstraintComparisons { + // return x509.CertificateInvalidError{c, x509.TooManyConstraints, ""} + // } + + // TODO: fix the errors; return our own, because we don't have cert ... + + for i := 0; i < excludedValue.Len(); i++ { + constraint := excludedValue.Index(i).Interface() + match, err := match(parsedName, constraint) + if err != nil { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: err.Error(), + } + } + + if match { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint), + } + } + } + + permittedValue := reflect.ValueOf(permitted) + + // *count += permittedValue.Len() + // if *count > maxConstraintComparisons { + // return x509.CertificateInvalidError{c, x509.TooManyConstraints, ""} + // } + + ok := true + for i := 0; i < permittedValue.Len(); i++ { + constraint := permittedValue.Index(i).Interface() + var err error + if ok, err = match(parsedName, constraint); err != nil { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: err.Error(), + } + } + + if ok { + break + } + } + + if !ok { + return CertificateInvalidError{ + Reason: x509.CANotAuthorizedForThisName, + Detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name), + } + } + + return nil +} + +// domainToReverseLabels converts a textual domain name like foo.example.com to +// the list of labels in reverse order, e.g. ["com", "example", "foo"]. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { + for len(domain) > 0 { + if i := strings.LastIndexByte(domain, '.'); i == -1 { + reverseLabels = append(reverseLabels, domain) + domain = "" + } else { + reverseLabels = append(reverseLabels, domain[i+1:]) + domain = domain[:i] + } + } + + if len(reverseLabels) > 0 && reverseLabels[0] == "" { + // An empty label at the end indicates an absolute value. + return nil, false + } + + for _, label := range reverseLabels { + if label == "" { + // Empty labels are otherwise invalid. + return nil, false + } + + for _, c := range label { + if c < 33 || c > 126 { + // Invalid character. + return nil, false + } + } + } + + return reverseLabels, true +} + +// rfc2821Mailbox represents a “mailbox” (which is an email address to most +// people) by breaking it into the “local” (i.e. before the '@') and “domain” +// parts. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +type rfc2821Mailbox struct { + local, domain string +} + +// parseRFC2821Mailbox parses an email address into local and domain parts, +// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280, +// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The +// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”. +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { + if in == "" { + return mailbox, false + } + + localPartBytes := make([]byte, 0, len(in)/2) + + if in[0] == '"' { + // Quoted-string = DQUOTE *qcontent DQUOTE + // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 + // qcontent = qtext / quoted-pair + // qtext = non-whitespace-control / + // %d33 / %d35-91 / %d93-126 + // quoted-pair = ("\" text) / obs-qp + // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text + // + // (Names beginning with “obs-” are the obsolete syntax from RFC 2822, + // Section 4. Since it has been 16 years, we no longer accept that.) + in = in[1:] + QuotedString: + for { + if in == "" { + return mailbox, false + } + c := in[0] + in = in[1:] + + switch { + case c == '"': + break QuotedString + + case c == '\\': + // quoted-pair + if in == "" { + return mailbox, false + } + if in[0] == 11 || + in[0] == 12 || + (1 <= in[0] && in[0] <= 9) || + (14 <= in[0] && in[0] <= 127) { + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + } else { + return mailbox, false + } + + case c == 11 || + c == 12 || + // Space (char 32) is not allowed based on the + // BNF, but RFC 3696 gives an example that + // assumes that it is. Several “verified” + // errata continue to argue about this point. + // We choose to accept it. + c == 32 || + c == 33 || + c == 127 || + (1 <= c && c <= 8) || + (14 <= c && c <= 31) || + (35 <= c && c <= 91) || + (93 <= c && c <= 126): + // qtext + localPartBytes = append(localPartBytes, c) + + default: + return mailbox, false + } + } + } else { + // Atom ("." Atom)* + NextChar: + for len(in) > 0 { + // atext from RFC 2822, Section 3.2.4 + c := in[0] + + switch { + case c == '\\': + // Examples given in RFC 3696 suggest that + // escaped characters can appear outside of a + // quoted string. Several “verified” errata + // continue to argue the point. We choose to + // accept it. + in = in[1:] + if in == "" { + return mailbox, false + } + fallthrough + + case ('0' <= c && c <= '9') || + ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || + c == '-' || c == '/' || c == '=' || c == '?' || + c == '^' || c == '_' || c == '`' || c == '{' || + c == '|' || c == '}' || c == '~' || c == '.': + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + + default: + break NextChar + } + } + + if len(localPartBytes) == 0 { + return mailbox, false + } + + // From RFC 3696, Section 3: + // “period (".") may also appear, but may not be used to start + // or end the local part, nor may two or more consecutive + // periods appear.” + twoDots := []byte{'.', '.'} + if localPartBytes[0] == '.' || + localPartBytes[len(localPartBytes)-1] == '.' || + bytes.Contains(localPartBytes, twoDots) { + return mailbox, false + } + } + + if in == "" || in[0] != '@' { + return mailbox, false + } + in = in[1:] + + // The RFC species a format for domains, but that's known to be + // violated in practice so we accept that anything after an '@' is the + // domain part. + if _, ok := domainToReverseLabels(in); !ok { + return mailbox, false + } + + mailbox.local = string(localPartBytes) + mailbox.domain = in + return mailbox, true +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchDomainConstraint(domain, constraint string) (bool, error) { + // The meaning of zero length constraints is not specified, but this + // code follows NSS and accepts them as matching everything. + if constraint == "" { + return true, nil + } + + domainLabels, ok := domainToReverseLabels(domain) + if !ok { + return false, fmt.Errorf("cannot parse domain %q", domain) + } + + // RFC 5280 says that a leading period in a domain name means that at + // least one label must be prepended, but only for URI and email + // constraints, not DNS constraints. The code also supports that + // behavior for DNS constraints. + + mustHaveSubdomains := false + if constraint[0] == '.' { + mustHaveSubdomains = true + constraint = constraint[1:] + } + + constraintLabels, ok := domainToReverseLabels(constraint) + if !ok { + return false, fmt.Errorf("cannot parse domain %q", constraint) + } + + if len(domainLabels) < len(constraintLabels) || + (mustHaveSubdomains && len(domainLabels) == len(constraintLabels)) { + return false, nil + } + + for i, constraintLabel := range constraintLabels { + if !strings.EqualFold(constraintLabel, domainLabels[i]) { + return false, nil + } + } + + return true, nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) { + + // TODO(hs): this is code from Go library, but I got some unexpected result: + // with permitted net 127.0.0.0/24, 127.0.0.1 is NOT allowed. When parsing 127.0.0.1 as net.IP + // which is in the IPAddresses slice, the underlying length is 16. The contraint.IP has a length + // of 4 instead. I currently don't believe that this is a bug in Go now, but why is it like that? + // Is there a difference because we're not operating on a sans []string slice? Or is the Go + // implementation stricter regarding IPv4 vs. IPv6? I've been bitten by some unfortunate differences + // between the two before (i.e. IPv4 in IPv6; IP SANS in ACME) + // if len(ip) != len(constraint.IP) { + // return false, nil + // } + + // for i := range ip { + // if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask { + // return false, nil + // } + // } + + // if isIPv4(ip) != isIPv4(constraint.IP) { // TODO(hs): this check seems to do what the above intended to do? + // return false, nil + // } + + contained := constraint.Contains(ip) // TODO(hs): validate that this is the correct behavior. + + return contained, nil +} + +func isIPv4(ip net.IP) bool { + return ip.To4() != nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) { + // If the constraint contains an @, then it specifies an exact mailbox name. + if strings.Contains(constraint, "@") { + constraintMailbox, ok := parseRFC2821Mailbox(constraint) + if !ok { + return false, fmt.Errorf("cannot parse constraint %q", constraint) + } + return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil + } + + // Otherwise the constraint is like a DNS constraint of the domain part + // of the mailbox. + return matchDomainConstraint(mailbox.domain, constraint) +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchURIConstraint(uri *url.URL, constraint string) (bool, error) { + // From RFC 5280, Section 4.2.1.10: + // “a uniformResourceIdentifier that does not include an authority + // component with a host name specified as a fully qualified domain + // name (e.g., if the URI either does not include an authority + // component or includes an authority component in which the host name + // is specified as an IP address), then the application MUST reject the + // certificate.” + + host := uri.Host + if host == "" { + return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String()) + } + + if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") { + var err error + host, _, err = net.SplitHostPort(uri.Host) + if err != nil { + return false, err + } + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") || + net.ParseIP(host) != nil { + return false, fmt.Errorf("URI with IP (%q) cannot be matched against constraints", uri.String()) + } + + return matchDomainConstraint(host, constraint) +} diff --git a/policy/x509/x509_test.go b/policy/x509/x509_test.go new file mode 100755 index 00000000..99c371ff --- /dev/null +++ b/policy/x509/x509_test.go @@ -0,0 +1,299 @@ +package x509policy + +import ( + "crypto/x509" + "net" + "net/url" + "testing" + + "github.com/smallstep/assert" +) + +func TestGuard_IsAllowed(t *testing.T) { + type fields struct { + permittedDNSDomains []string + excludedDNSDomains []string + permittedIPRanges []*net.IPNet + excludedIPRanges []*net.IPNet + permittedEmailAddresses []string + excludedEmailAddresses []string + permittedURIDomains []string + excludedURIDomains []string + } + tests := []struct { + name string + fields fields + csr *x509.CertificateRequest + want bool + wantErr bool + }{ + { + name: "fail/dns-permitted", + fields: fields{ + permittedDNSDomains: []string{".local"}, + }, + csr: &x509.CertificateRequest{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/dns-excluded", + fields: fields{ + excludedDNSDomains: []string{"example.com"}, + }, + csr: &x509.CertificateRequest{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/ipv4-permitted", + fields: fields{ + permittedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("1.1.1.1")}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/ipv4-excluded", + fields: fields{ + excludedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/ipv6-permitted", + fields: fields{ + permittedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("3001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/ipv6-excluded", + fields: fields{ + excludedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/mail-permitted", + fields: fields{ + permittedEmailAddresses: []string{"example.local"}, + }, + csr: &x509.CertificateRequest{ + EmailAddresses: []string{"mail@example.com"}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/mail-excluded", + fields: fields{ + excludedEmailAddresses: []string{"example.local"}, + }, + csr: &x509.CertificateRequest{ + EmailAddresses: []string{"mail@example.local"}, + }, + want: false, + wantErr: true, + }, + { + name: "fail/uri-permitted", + fields: fields{ + permittedURIDomains: []string{".example.com"}, + }, + csr: &x509.CertificateRequest{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: false, + wantErr: true, + }, + { + name: "fail/uri-excluded", + fields: fields{ + excludedURIDomains: []string{".example.local"}, + }, + csr: &x509.CertificateRequest{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: false, + wantErr: true, + }, + { + name: "ok/no-constraints", + fields: fields{}, + csr: &x509.CertificateRequest{ + DNSNames: []string{"www.example.com"}, + }, + want: true, + wantErr: false, + }, + { + name: "ok/dns", + fields: fields{ + permittedDNSDomains: []string{".local"}, + }, + csr: &x509.CertificateRequest{ + DNSNames: []string{"example.local"}, + }, + want: true, + wantErr: false, + }, + { + name: "ok/ipv4", + fields: fields{ + permittedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.20")}, + }, + want: true, + wantErr: false, + }, + { + name: "ok/ipv6", + fields: fields{ + permittedIPRanges: []*net.IPNet{ + { + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7339")}, + }, + want: true, + wantErr: false, + }, + { + name: "ok/mail", + fields: fields{ + permittedEmailAddresses: []string{"example.local"}, + }, + csr: &x509.CertificateRequest{ + EmailAddresses: []string{"mail@example.local"}, + }, + want: true, + wantErr: false, + }, + { + name: "ok/uri", + fields: fields{ + permittedURIDomains: []string{".example.com"}, + }, + csr: &x509.CertificateRequest{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + wantErr: false, + }, + { + name: "ok/combined-simple", + fields: fields{ + permittedDNSDomains: []string{".local"}, + permittedIPRanges: []*net.IPNet{{IP: net.ParseIP("127.0.0.1"), Mask: net.IPv4Mask(255, 255, 255, 0)}}, + permittedEmailAddresses: []string{"example.local"}, + permittedURIDomains: []string{".example.local"}, + }, + csr: &x509.CertificateRequest{ + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: true, + wantErr: false, + }, + // TODO: more complex uses cases that combine multiple names + // TODO: check errors (reasons) are as expected + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := &NamePolicyEngine{ + permittedDNSDomains: tt.fields.permittedDNSDomains, + excludedDNSDomains: tt.fields.excludedDNSDomains, + permittedIPRanges: tt.fields.permittedIPRanges, + excludedIPRanges: tt.fields.excludedIPRanges, + permittedEmailAddresses: tt.fields.permittedEmailAddresses, + excludedEmailAddresses: tt.fields.excludedEmailAddresses, + permittedURIDomains: tt.fields.permittedURIDomains, + excludedURIDomains: tt.fields.excludedURIDomains, + } + got, err := g.AreCSRNamesAllowed(tt.csr) + if (err != nil) != tt.wantErr { + t.Errorf("Guard.IsAllowed() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.NotEquals(t, "", err.Error()) // TODO(hs): make this a complete equality check + } + if got != tt.want { + t.Errorf("Guard.IsAllowed() = %v, want %v", got, tt.want) + } + }) + } +}