From 3fb5e57f12f49b0b2bae2682f10426250b79e8df Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 4 Mar 2022 10:56:09 -0800 Subject: [PATCH 01/44] Upgrade nosql package The new version of the package allows filtering out database drivers using Go tags. --- CHANGELOG.md | 4 ++++ go.mod | 2 +- go.sum | 6 ++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28c2f141..983e5cf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.2] - DATE ### Added - Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. +- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering + out database drivers using Go tags. For example, using the Go flag + `--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx + driver for PostgreSQL. ### Changed - IPv6 addresses are normalized as IP addresses instead of hostnames. - More descriptive JWK decryption error message. diff --git a/go.mod b/go.mod index 46fe260c..e6696529 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/sirupsen/logrus v1.8.1 github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 - github.com/smallstep/nosql v0.3.10 + github.com/smallstep/nosql v0.4.0 github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 diff --git a/go.sum b/go.sum index 1cd8e2e7..123df6e4 100644 --- a/go.sum +++ b/go.sum @@ -607,8 +607,8 @@ github.com/slackhq/nebula v1.5.2/go.mod h1:xaCM6wqbFk/NRmmUe1bv88fWBm3a1UioXJVIp github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= -github.com/smallstep/nosql v0.3.10 h1:Xs7nueSl250GYb5XdfbzR8w+xPbvF6/oSw6pryY7gJI= -github.com/smallstep/nosql v0.3.10/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo= +github.com/smallstep/nosql v0.4.0 h1:Go3WYwttUuvwqMtFiiU4g7kBIlY+hR0bIZAqVdakQ3M= +github.com/smallstep/nosql v0.4.0/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -685,8 +685,6 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/ go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.9.2 h1:CpAkd174sLXFfrOZrbPEiTzik91QRj3+L0omsiwsiok= -go.step.sm/linkedca v0.9.2/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.step.sm/linkedca v0.10.0 h1:+bqymMRulHYkVde4l16FnqFVskoS6HCWJN5Z5cxAqF8= go.step.sm/linkedca v0.10.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= From e6992442916cb2e1f8c56fb01a499e7836b8219c Mon Sep 17 00:00:00 2001 From: vijayjt <2975049+vijayjt@users.noreply.github.com> Date: Mon, 7 Mar 2022 11:24:58 +0000 Subject: [PATCH 02/44] Support Azure tokens from managed identities not associated with a VM --- authority/provisioner/azure.go | 14 +++++++++++--- authority/provisioner/azure_test.go | 22 +++++++++++----------- authority/provisioner/utils_test.go | 12 +++++++++--- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 384617e0..391034bc 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/" // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // Using case insensitive as resourceGroups appears as resourcegroups. -var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) +var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) type azureConfig struct { oidcDiscoveryURL string @@ -263,11 +263,19 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str } re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) - if len(re) != 4 { + if len(re) != 5 { return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) } + + var subscription, group, name string identityObjectID := claims.ObjectID - subscription, group, name := re[1], re[2], re[3] + + if strings.Contains(claims.XMSMirID, "virtualMachines") { + subscription, group, name = re[1], re[2], re[4] + } else { + // This is not a VM resource ID so we don't have the VM name so set that to the empty string + subscription, group, name = re[1], re[2], "" + } return &claims, name, group, subscription, identityObjectID, nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 4ab734d5..69f98502 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -95,7 +95,7 @@ func TestAzure_GetIdentityToken(t *testing.T) { assert.FatalError(t, err) t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) @@ -237,7 +237,7 @@ func TestAzure_authorizeToken(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), jwk) assert.FatalError(t, err) return test{ @@ -252,7 +252,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -267,7 +267,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - "foo", "subscriptionID", "resourceGroup", "virtualMachine", + "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -321,7 +321,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -437,28 +437,28 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), badKey) assert.FatalError(t, err) diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..d0992f0a 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -671,7 +671,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) { w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(az.keyStore.keySet)) case "/metadata/identity/oauth2/token": - tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) + tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0]) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } else { @@ -1009,7 +1009,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r return jose.Signed(sig).Claims(claims).CompactSerialize() } -func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { +func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName string, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), @@ -1017,6 +1017,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, if err != nil { return "", err } + var xmsMirID string + if resourceType == "vm" { + xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName) + } else if resourceType == "uai" { + xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName) + } claims := azurePayload{ Claims: jose.Claims{ @@ -1034,7 +1040,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, ObjectID: "the-oid", TenantID: tenantID, Version: "the-version", - XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), + XMSMirID: xmsMirID, } return jose.Signed(sig).Claims(claims).CompactSerialize() } From 4822516d727fd24fc93a881bf735a3bdd460d30d Mon Sep 17 00:00:00 2001 From: vijayjt <2975049+vijayjt@users.noreply.github.com> Date: Mon, 7 Mar 2022 12:07:48 +0000 Subject: [PATCH 03/44] Remove redundant parameter type declaration --- authority/provisioner/utils_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index d0992f0a..01cdc6f1 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -1009,7 +1009,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r return jose.Signed(sig).Claims(claims).CompactSerialize() } -func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName string, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { +func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), From a3cda9c3d7a21e427f2e6752b2676e5b3977c895 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 7 Mar 2022 13:16:53 +0100 Subject: [PATCH 04/44] Add configuration for custom path segment To support SCEP clients that expect a specific path segment in a SCEP URL, a new "customPath" option was added to the SCEP provisioner configuration. The configuration can be used to set a specific path (segment) that the SCEP provisioner will respond to. --- authority/provisioner/scep.go | 16 ++++++++++++---- scep/api/api.go | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 5d67762c..05802ffb 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -26,10 +26,18 @@ type SCEP struct { // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + // CustomPath is used to specify a custom path on which the SCEP provisioner will be made + // available. By default a SCEP provisioner is available at + // https://
:/scep/ and requests performed looking similar + // to https://
:/scep/?operations=GetCACert. When CustomPath + // is set, the SCEP URL will be https://
:/scep//, + // resulting in SCEP clients that expect a specific path, such as "/pkiclient.exe", to be + // able to interact with the SCEP provisioner. + CustomPath string `json:"customPath,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + claimer *Claimer secretChallengePassword string encryptionAlgorithm int diff --git a/scep/api/api.go b/scep/api/api.go index 4f8d897b..9b48187a 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -66,7 +66,9 @@ func New(scepAuth scep.Interface) api.RouterHandler { // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { getLink := h.Auth.GetLinkExplicit + r.MethodFunc(http.MethodGet, getLink("{provisionerName}/{customPath}*", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) + r.MethodFunc(http.MethodPost, getLink("{provisionerName}/{customPath}*", false, nil), h.lookupProvisioner(h.Post)) r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) } @@ -191,6 +193,13 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return } + customPathParam := chi.URLParam(r, "customPath") + customPath, err := url.PathUnescape(customPathParam) + if err != nil { + api.WriteError(w, err) + return + } + p, err := h.Auth.LoadProvisionerByName(provisionerName) if err != nil { api.WriteError(w, err) @@ -203,6 +212,12 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return } + configuredCustomPath := strings.Trim(prov.CustomPath, "/") + if customPath != configuredCustomPath { + api.WriteError(w, errors.Errorf("custom path requested '%s' is not the expected path '%s'", customPath, configuredCustomPath)) + return + } + ctx := r.Context() ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) From 7c541888ad281e0c6f669ce3963c0f305ac84a62 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 8 Mar 2022 13:26:07 +0100 Subject: [PATCH 05/44] Refactor configuration of allow/deny on authority level --- authority/authority.go | 21 +++ authority/config/config.go | 2 + authority/policy/options.go | 170 ++++++++++++++++++++++ authority/policy/policy.go | 134 +++++++++++++++++ authority/provisioner/acme.go | 6 +- authority/provisioner/aws.go | 9 +- authority/provisioner/azure.go | 9 +- authority/provisioner/gcp.go | 9 +- authority/provisioner/jwk.go | 13 +- authority/provisioner/k8sSA.go | 13 +- authority/provisioner/nebula.go | 13 +- authority/provisioner/oidc.go | 13 +- authority/provisioner/options.go | 32 ++-- authority/provisioner/policy.go | 156 -------------------- authority/provisioner/scep.go | 6 +- authority/provisioner/sign_options.go | 8 +- authority/provisioner/sign_ssh_options.go | 30 ++-- authority/provisioner/ssh_options.go | 92 ++++++------ authority/provisioner/x5c.go | 13 +- authority/ssh.go | 40 +++++ authority/tls.go | 19 +++ 21 files changed, 515 insertions(+), 293 deletions(-) create mode 100644 authority/policy/options.go create mode 100644 authority/policy/policy.go delete mode 100644 authority/provisioner/policy.go diff --git a/authority/authority.go b/authority/authority.go index f396c588..4eacfad7 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -16,6 +16,7 @@ import ( adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" @@ -75,6 +76,11 @@ type Authority struct { sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) getIdentityFunc provisioner.GetIdentityFunc + // Policy engines + x509Policy policy.X509Policy + sshUserPolicy policy.UserPolicy + sshHostPolicy policy.HostPolicy + adminMutex sync.RWMutex } @@ -539,6 +545,21 @@ func (a *Authority) init() error { a.templates.Data["Step"] = tmplVars } + // Initialize the x509 allow/deny policy engine + if a.x509Policy, err = policy.NewX509PolicyEngine(a.config.AuthorityConfig.Policy.GetX509Options()); err != nil { + return err + } + + // // Initialize the SSH allow/deny policy engine for host certificates + if a.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(a.config.AuthorityConfig.Policy.GetSSHOptions()); err != nil { + return err + } + + // // Initialize the SSH allow/deny policy engine for user certificates + if a.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(a.config.AuthorityConfig.Policy.GetSSHOptions()); err != nil { + return err + } + // JWT numeric dates are seconds. a.startTime = time.Now().Truncate(time.Second) // Set flag indicating that initialization has been completed, and should diff --git a/authority/config/config.go b/authority/config/config.go index 589b5bbf..0f6120f9 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" cas "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" @@ -90,6 +91,7 @@ type AuthConfig struct { Admins []*linkedca.Admin `json:"-"` Template *ASN1DN `json:"template,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"` + Policy *policy.Options `json:"policy,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"` diff --git a/authority/policy/options.go b/authority/policy/options.go new file mode 100644 index 00000000..f57f3bcf --- /dev/null +++ b/authority/policy/options.go @@ -0,0 +1,170 @@ +package policy + +type Options struct { + X509 *X509PolicyOptions `json:"x509,omitempty"` + SSH *SSHPolicyOptions `json:"ssh,omitempty"` +} + +func (o *Options) GetX509Options() *X509PolicyOptions { + if o == nil { + return nil + } + return o.X509 +} + +func (o *Options) GetSSHOptions() *SSHPolicyOptions { + if o == nil { + return nil + } + return o.SSH +} + +type X509PolicyOptionsInterface interface { + GetAllowedNameOptions() *X509NameOptions + GetDeniedNameOptions() *X509NameOptions +} + +type X509PolicyOptions struct { + // AllowedNames ... + AllowedNames *X509NameOptions `json:"allow,omitempty"` + + // DeniedNames ... + DeniedNames *X509NameOptions `json:"deny,omitempty"` +} + +// X509NameOptions models the X509 name policy configuration. +type X509NameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +// HasNames checks if the AllowedNameOptions has one or more +// names configured. +func (o *X509NameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.URIDomains) > 0 +} + +type SSHPolicyOptionsInterface interface { + GetAllowedUserNameOptions() *SSHNameOptions + GetDeniedUserNameOptions() *SSHNameOptions + GetAllowedHostNameOptions() *SSHNameOptions + GetDeniedHostNameOptions() *SSHNameOptions +} + +type SSHPolicyOptions struct { + // User contains SSH user certificate options. + User *SSHUserCertificateOptions `json:"user,omitempty"` + + // Host contains SSH host certificate options. + Host *SSHHostCertificateOptions `json:"host,omitempty"` +} + +// GetAllowedNameOptions returns AllowedNames, which models the +// SANs that ... +func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedNames, which models the +// SANs that ... +func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.AllowedNames +} + +func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.DeniedNames +} + +func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.DeniedNames +} + +// SSHUserCertificateOptions is a collection of SSH user certificate options. +type SSHUserCertificateOptions struct { + // AllowedNames contains the names the provisioner is authorized to sign + AllowedNames *SSHNameOptions `json:"allow,omitempty"` + // DeniedNames contains the names the provisioner is not authorized to sign + DeniedNames *SSHNameOptions `json:"deny,omitempty"` +} + +// SSHHostCertificateOptions is a collection of SSH host certificate options. +// It's an alias of SSHUserCertificateOptions, as the options are the same +// for both types of certificates. +type SSHHostCertificateOptions SSHUserCertificateOptions + +// SSHNameOptions models the SSH name policy configuration. +type SSHNameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the +// names that a provisioner is authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the +// names that a provisioner is NOT authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// HasNames checks if the SSHNameOptions has one or more +// names configured. +func (o *SSHNameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.EmailAddresses) > 0 || + len(o.Principals) > 0 +} diff --git a/authority/policy/policy.go b/authority/policy/policy.go new file mode 100644 index 00000000..403ac0b7 --- /dev/null +++ b/authority/policy/policy.go @@ -0,0 +1,134 @@ +package policy + +import ( + "fmt" + + "github.com/smallstep/certificates/policy" +) + +// X509Policy is an alias for policy.X509NamePolicyEngine +type X509Policy policy.X509NamePolicyEngine + +// UserPolicy is an alias for policy.SSHNamePolicyEngine +type UserPolicy policy.SSHNamePolicyEngine + +// HostPolicy is an alias for policy.SSHNamePolicyEngine +type HostPolicy policy.SSHNamePolicyEngine + +// NewX509PolicyEngine creates a new x509 name policy engine +func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + options := []policy.NamePolicyOption{} + + allowed := policyOptions.GetAllowedNameOptions() + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedDNSDomains(allowed.DNSDomains), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses), + policy.WithPermittedURIDomains(allowed.URIDomains), + ) + } + + denied := policyOptions.GetDeniedNameOptions() + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedDNSDomains(denied.DNSDomains), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges), + policy.WithExcludedEmailAddresses(denied.EmailAddresses), + policy.WithExcludedURIDomains(denied.URIDomains), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + // enable x509 Subject Common Name validation by default + options = append(options, policy.WithSubjectCommonNameVerification()) + + return policy.New(options...) +} + +type sshPolicyEngineType string + +const ( + UserPolicyEngineType sshPolicyEngineType = "user" + HostPolicyEngineType sshPolicyEngineType = "host" +) + +// newSSHUserPolicyEngine creates a new SSH user certificate policy engine +func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHHostPolicyEngine create a new SSH host certificate policy engine +func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHPolicyEngine creates a new SSH name policy engine +func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + var ( + allowed *SSHNameOptions + denied *SSHNameOptions + ) + + switch typ { + case UserPolicyEngineType: + allowed = policyOptions.GetAllowedUserNameOptions() + denied = policyOptions.GetDeniedUserNameOptions() + case HostPolicyEngineType: + allowed = policyOptions.GetAllowedHostNameOptions() + denied = policyOptions.GetDeniedHostNameOptions() + default: + return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ) + } + + options := []policy.NamePolicyOption{} + + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedDNSDomains(allowed.DNSDomains), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses), + policy.WithPermittedPrincipals(allowed.Principals), + ) + } + + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedDNSDomains(denied.DNSDomains), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges), + policy.WithExcludedEmailAddresses(denied.EmailAddresses), + policy.WithExcludedPrincipals(denied.Principals), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + return policy.New(options...) +} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 05d16e7f..2d5f74ff 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -7,8 +7,8 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/policy" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -27,7 +27,7 @@ type ACME struct { Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` claimer *Claimer - x509Policy policy.X509NamePolicyEngine + x509Policy policy.X509Policy } // GetID returns the provisioner unique identifier. @@ -92,7 +92,7 @@ func (p *ACME) Init(config Config) (err error) { // Initialize the x509 allow/deny policy engine // TODO(hs): ensure no race conditions happen when reloading settings and requesting certs? // TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny? - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 2ff8ade9..81029b1d 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -17,6 +17,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -267,8 +268,8 @@ type AWS struct { claimer *Claimer config *awsConfig audiences Audiences - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy } // GetID returns the provisioner unique identifier. @@ -428,12 +429,12 @@ func (p *AWS) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index f010364c..9c596b11 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -100,8 +101,8 @@ type Azure struct { config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy } // GetID returns the provisioner unique identifier. @@ -226,12 +227,12 @@ func (p *Azure) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e56c0729..5f08f2f6 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -14,6 +14,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -92,8 +93,8 @@ type GCP struct { config *gcpConfig keyStore *keyStore audiences Audiences - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy } // GetID returns the provisioner unique identifier. The name should uniquely @@ -219,12 +220,12 @@ func (p *GCP) Init(config Config) error { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index a129a536..b1716233 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -37,9 +38,9 @@ type JWK struct { Options *Options `json:"options,omitempty"` claimer *Claimer audiences Audiences - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine - sshUserPolicy *userPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy } // GetID returns the provisioner unique identifier. The name and credential id @@ -107,17 +108,17 @@ func (p *JWK) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for user certificates - if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index be55f114..7737c1cc 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" @@ -52,9 +53,9 @@ type K8sSA struct { audiences Audiences //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine - sshUserPolicy *userPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy } // GetID returns the provisioner unique identifier. The name and credential id @@ -148,17 +149,17 @@ func (p *K8sSA) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for user certificates - if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index f8027de9..a9bfab9f 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" nebula "github.com/slackhq/nebula/cert" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -43,9 +44,9 @@ type Nebula struct { claimer *Claimer caPool *nebula.NebulaCAPool audiences Audiences - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine - sshUserPolicy *userPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy } // Init verifies and initializes the Nebula provisioner. @@ -72,17 +73,17 @@ func (p *Nebula) Init(config Config) error { p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for user certificates - if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 60bb5cf1..e3c8740a 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -94,9 +95,9 @@ type OIDC struct { keyStore *keyStore claimer *Claimer getIdentityFunc GetIdentityFunc - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine - sshUserPolicy *userPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy } func sanitizeEmail(email string) string { @@ -212,17 +213,17 @@ func (o *OIDC) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - if o.x509Policy, err = newX509PolicyEngine(o.Options.GetX509Options()); err != nil { + if o.x509Policy, err = policy.NewX509PolicyEngine(o.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for user certificates - if o.sshUserPolicy, err = newSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil { + if o.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if o.sshHostPolicy, err = newSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil { + if o.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index 257a2107..7725c8b0 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -5,8 +5,11 @@ import ( "strings" "github.com/pkg/errors" + "go.step.sm/crypto/jose" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/policy" ) // CertificateOptions is an interface that returns a list of options passed when @@ -58,10 +61,10 @@ type X509Options struct { TemplateData json.RawMessage `json:"templateData,omitempty"` // AllowedNames contains the SANs the provisioner is authorized to sign - AllowedNames *X509NameOptions `json:"allow,omitempty"` + AllowedNames *policy.X509NameOptions // DeniedNames contains the SANs the provisioner is not authorized to sign - DeniedNames *X509NameOptions `json:"deny,omitempty"` + DeniedNames *policy.X509NameOptions } // HasTemplate returns true if a template is defined in the provisioner options. @@ -69,41 +72,24 @@ func (o *X509Options) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } -// GetAllowedNameOptions returns the AllowedNameOptions, which models the +// GetAllowedNameOptions returns the AllowedNames, which models the // SANs that a provisioner is authorized to sign x509 certificates for. -func (o *X509Options) GetAllowedNameOptions() *X509NameOptions { +func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions { if o == nil { return nil } return o.AllowedNames } -// GetDeniedNameOptions returns the DeniedNameOptions, which models the +// GetDeniedNameOptions returns the DeniedNames, which models the // SANs that a provisioner is NOT authorized to sign x509 certificates for. -func (o *X509Options) GetDeniedNameOptions() *X509NameOptions { +func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions { if o == nil { return nil } return o.DeniedNames } -// X509NameOptions models the X509 name policy configuration. -type X509NameOptions struct { - DNSDomains []string `json:"dns,omitempty"` - IPRanges []string `json:"ip,omitempty"` - EmailAddresses []string `json:"email,omitempty"` - URIDomains []string `json:"uri,omitempty"` -} - -// HasNames checks if the AllowedNameOptions has one or more -// names configured. -func (o *X509NameOptions) HasNames() bool { - return len(o.DNSDomains) > 0 || - len(o.IPRanges) > 0 || - len(o.EmailAddresses) > 0 || - len(o.URIDomains) > 0 -} - // TemplateOptions generates a CertificateOptions with the template and data // defined in the ProvisionerOptions, the provisioner generated data, and the // user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/policy.go b/authority/provisioner/policy.go deleted file mode 100644 index b9740e39..00000000 --- a/authority/provisioner/policy.go +++ /dev/null @@ -1,156 +0,0 @@ -package provisioner - -import ( - "fmt" - - "github.com/smallstep/certificates/policy" - "golang.org/x/crypto/ssh" -) - -type sshPolicyEngineType string - -const ( - userPolicyEngineType sshPolicyEngineType = "user" - hostPolicyEngineType sshPolicyEngineType = "host" -) - -var certTypeToPolicyEngineType = map[uint32]sshPolicyEngineType{ - uint32(ssh.UserCert): userPolicyEngineType, - uint32(ssh.HostCert): hostPolicyEngineType, -} - -type x509PolicyEngine interface { - policy.X509NamePolicyEngine -} - -type userPolicyEngine struct { - policy.SSHNamePolicyEngine -} - -type hostPolicyEngine struct { - policy.SSHNamePolicyEngine -} - -// newX509PolicyEngine creates a new x509 name policy engine -func newX509PolicyEngine(x509Opts *X509Options) (x509PolicyEngine, error) { - - if x509Opts == nil { - return nil, nil - } - - options := []policy.NamePolicyOption{ - policy.WithSubjectCommonNameVerification(), // enable x509 Subject Common Name validation by default - } - - allowed := x509Opts.GetAllowedNameOptions() - if allowed != nil && allowed.HasNames() { - options = append(options, - policy.WithPermittedDNSDomains(allowed.DNSDomains), - policy.WithPermittedIPsOrCIDRs(allowed.IPRanges), - policy.WithPermittedEmailAddresses(allowed.EmailAddresses), - policy.WithPermittedURIDomains(allowed.URIDomains), - ) - } - - denied := x509Opts.GetDeniedNameOptions() - if denied != nil && denied.HasNames() { - options = append(options, - policy.WithExcludedDNSDomains(denied.DNSDomains), - policy.WithExcludedIPsOrCIDRs(denied.IPRanges), - policy.WithExcludedEmailAddresses(denied.EmailAddresses), - policy.WithExcludedURIDomains(denied.URIDomains), - ) - } - - return policy.New(options...) -} - -// newSSHUserPolicyEngine creates a new SSH user certificate policy engine -func newSSHUserPolicyEngine(sshOpts *SSHOptions) (*userPolicyEngine, error) { - policyEngine, err := newSSHPolicyEngine(sshOpts, userPolicyEngineType) - if err != nil { - return nil, err - } - // ensure we're not wrapping a nil engine - if policyEngine == nil { - return nil, nil - } - return &userPolicyEngine{ - SSHNamePolicyEngine: policyEngine, - }, nil -} - -// newSSHHostPolicyEngine create a new SSH host certificate policy engine -func newSSHHostPolicyEngine(sshOpts *SSHOptions) (*hostPolicyEngine, error) { - policyEngine, err := newSSHPolicyEngine(sshOpts, hostPolicyEngineType) - if err != nil { - return nil, err - } - // ensure we're not wrapping a nil engine - if policyEngine == nil { - return nil, nil - } - return &hostPolicyEngine{ - SSHNamePolicyEngine: policyEngine, - }, nil -} - -// newSSHPolicyEngine creates a new SSH name policy engine -func newSSHPolicyEngine(sshOpts *SSHOptions, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) { - - if sshOpts == nil { - return nil, nil - } - - var ( - allowed *SSHNameOptions - denied *SSHNameOptions - ) - - // TODO: embed the type in the policy engine itself for reference? - switch typ { - case userPolicyEngineType: - if sshOpts.User != nil { - allowed = sshOpts.User.GetAllowedNameOptions() - denied = sshOpts.User.GetDeniedNameOptions() - } - case hostPolicyEngineType: - if sshOpts.Host != nil { - allowed = sshOpts.Host.AllowedNames - denied = sshOpts.Host.DeniedNames - } - default: - return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ) - } - - options := []policy.NamePolicyOption{} - - if allowed != nil && allowed.HasNames() { - options = append(options, - policy.WithPermittedDNSDomains(allowed.DNSDomains), - policy.WithPermittedIPsOrCIDRs(allowed.IPRanges), - policy.WithPermittedEmailAddresses(allowed.EmailAddresses), - policy.WithPermittedPrincipals(allowed.Principals), - ) - } - - if denied != nil && denied.HasNames() { - options = append(options, - policy.WithExcludedDNSDomains(denied.DNSDomains), - policy.WithExcludedIPsOrCIDRs(denied.IPRanges), - policy.WithExcludedEmailAddresses(denied.EmailAddresses), - policy.WithExcludedPrincipals(denied.Principals), - ) - } - - // Return nil, because there's no policy to execute. This is - // important, because the logic that determines user vs. host certs - // are allowed depends on this fact. The two policy engines are - // not aware of eachother, so this check is performed in the - // SSH name validator, instead. - if len(options) == 0 { - return nil, nil - } - - return policy.New(options...) -} diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index a9a06cae..9d02aebb 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -5,7 +5,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/policy" + "github.com/smallstep/certificates/authority/policy" ) // SCEP is the SCEP provisioner type, an entity that can authorize the @@ -31,7 +31,7 @@ type SCEP struct { Options *Options `json:"options,omitempty"` Claims *Claims `json:"claims,omitempty"` claimer *Claimer - x509Policy policy.X509NamePolicyEngine + x509Policy policy.X509Policy secretChallengePassword string encryptionAlgorithm int } @@ -116,7 +116,7 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? // Initialize the x509 allow/deny policy engine - if s.x509Policy, err = newX509PolicyEngine(s.Options.GetX509Options()); err != nil { + if s.x509Policy, err = policy.NewX509PolicyEngine(s.Options.GetX509Options()); err != nil { return err } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 3327310b..082d765d 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" @@ -407,18 +408,17 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { // x509NamePolicyValidator validates that the certificate (to be signed) // contains only allowed SANs. type x509NamePolicyValidator struct { - policyEngine x509PolicyEngine + policyEngine policy.X509Policy } // newX509NamePolicyValidator return a new SANs allow/deny validator. -func newX509NamePolicyValidator(engine x509PolicyEngine) *x509NamePolicyValidator { +func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator { return &x509NamePolicyValidator{ policyEngine: engine, } } -// Valid validates validates that the certificate (to be signed) -// contains only allowed SANs. +// Valid validates that the certificate (to be signed) contains only allowed SANs. func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { if v.policyEngine == nil { return nil diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 8f9cf466..a057b2b9 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" @@ -448,20 +449,19 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti // sshNamePolicyValidator validates that the certificate (to be signed) // contains only allowed principals. type sshNamePolicyValidator struct { - hostPolicyEngine *hostPolicyEngine - userPolicyEngine *userPolicyEngine + hostPolicyEngine policy.HostPolicy + userPolicyEngine policy.UserPolicy } // newSSHNamePolicyValidator return a new SSH allow/deny validator. -func newSSHNamePolicyValidator(host *hostPolicyEngine, user *userPolicyEngine) *sshNamePolicyValidator { +func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator { return &sshNamePolicyValidator{ hostPolicyEngine: host, userPolicyEngine: user, } } -// Valid validates validates that the certificate (to be signed) -// contains only allowed principals. +// Valid validates that the certificate (to be signed) contains only allowed principals. func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { if v.hostPolicyEngine == nil && v.userPolicyEngine == nil { // no policy configured at all; allow anything @@ -473,29 +473,25 @@ func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) // the same for host certs: if only a user policy engine is configured, host // certs are denied. When both policy engines are configured, the type of // cert determines which policy engine is used. - policyType, ok := certTypeToPolicyEngineType[cert.CertType] - if !ok { - return fmt.Errorf("unexpected SSH cert type %d", cert.CertType) - } - switch policyType { - case hostPolicyEngineType: + switch cert.CertType { + case ssh.HostCert: // when no host policy engine is configured, but a user policy engine is - // configured, we don't allow the host certificate. + // configured, the host certificate is denied. if v.hostPolicyEngine == nil && v.userPolicyEngine != nil { - return errors.New("SSH host certificate not authorized") // TODO: include principals in message? + return errors.New("SSH host certificate not authorized") } _, err := v.hostPolicyEngine.ArePrincipalsAllowed(cert) return err - case userPolicyEngineType: + case ssh.UserCert: // when no user policy engine is configured, but a host policy engine is - // configured, we don't allow the user certificate. + // configured, the user certificate is denied. if v.userPolicyEngine == nil && v.hostPolicyEngine != nil { - return errors.New("SSH user certificate not authorized") // TODO: include principals in message? + return errors.New("SSH user certificate not authorized") } _, err := v.userPolicyEngine.ArePrincipalsAllowed(cert) return err default: - return fmt.Errorf("unexpected policy engine type %q", policyType) // satisfy return; shouldn't happen + return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen } } diff --git a/authority/provisioner/ssh_options.go b/authority/provisioner/ssh_options.go index dacafc80..92c5826b 100644 --- a/authority/provisioner/ssh_options.go +++ b/authority/provisioner/ssh_options.go @@ -6,6 +6,8 @@ import ( "github.com/pkg/errors" "go.step.sm/crypto/sshutil" + + "github.com/smallstep/certificates/authority/policy" ) // SSHCertificateOptions is an interface that returns a list of options passed when @@ -35,32 +37,58 @@ type SSHOptions struct { TemplateData json.RawMessage `json:"templateData,omitempty"` // User contains SSH user certificate options. - User *SSHUserCertificateOptions `json:"user,omitempty"` + User *policy.SSHUserCertificateOptions // Host contains SSH host certificate options. - Host *SSHHostCertificateOptions `json:"host,omitempty"` + Host *policy.SSHHostCertificateOptions } -// SSHUserCertificateOptions is a collection of SSH user certificate options. -type SSHUserCertificateOptions struct { - // AllowedNames contains the names the provisioner is authorized to sign - AllowedNames *SSHNameOptions `json:"allow,omitempty"` - - // DeniedNames contains the names the provisioner is not authorized to sign - DeniedNames *SSHNameOptions `json:"deny,omitempty"` +// GetAllowedUserNameOptions returns the SSHNameOptions that are +// allowed when SSH User certificates are requested. +func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.AllowedNames } -// SSHHostCertificateOptions is a collection of SSH host certificate options. -// It's an alias of SSHUserCertificateOptions, as the options are the same -// for both types of certificates. -type SSHHostCertificateOptions SSHUserCertificateOptions +// GetDeniedUserNameOptions returns the SSHNameOptions that are +// denied when SSH user certificates are requested. +func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.DeniedNames +} -// SSHNameOptions models the SSH name policy configuration. -type SSHNameOptions struct { - DNSDomains []string `json:"dns,omitempty"` - IPRanges []string `json:"ip,omitempty"` - EmailAddresses []string `json:"email,omitempty"` - Principals []string `json:"principal,omitempty"` +// GetAllowedHostNameOptions returns the SSHNameOptions that are +// allowed when SSH host certificates are requested. +func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +// GetDeniedHostNameOptions returns the SSHNameOptions that are +// denied when SSH host certificates are requested. +func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.DeniedNames } // HasTemplate returns true if a template is defined in the provisioner options. @@ -68,32 +96,6 @@ func (o *SSHOptions) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } -// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the -// names that a provisioner is authorized to sign SSH certificates for. -func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions { - if o == nil { - return nil - } - return o.AllowedNames -} - -// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the -// names that a provisioner is NOT authorized to sign SSH certificates for. -func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions { - if o == nil { - return nil - } - return o.DeniedNames -} - -// HasNames checks if the SSHNameOptions has one or more -// names configured. -func (o *SSHNameOptions) HasNames() bool { - return len(o.DNSDomains) > 0 || - len(o.EmailAddresses) > 0 || - len(o.Principals) > 0 -} - // TemplateSSHOptions generates a SSHCertificateOptions with the template and // data defined in the ProvisionerOptions, the provisioner generated data, and // the user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 12112cc6..a8275474 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" @@ -35,9 +36,9 @@ type X5C struct { claimer *Claimer audiences Audiences rootPool *x509.CertPool - x509Policy x509PolicyEngine - sshHostPolicy *hostPolicyEngine - sshUserPolicy *userPolicyEngine + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy } // GetID returns the provisioner unique identifier. The name and credential id @@ -129,17 +130,17 @@ func (p *X5C) Init(config Config) error { } // Initialize the x509 allow/deny policy engine - if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { + if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } // Initialize the SSH allow/deny policy engine for user certificates - if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } // Initialize the SSH allow/deny policy engine for host certificates - if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { + if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { return err } diff --git a/authority/ssh.go b/authority/ssh.go index 4a67b28c..7c3df192 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/pkg/errors" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" @@ -241,6 +242,45 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) } + switch certTpl.CertType { + case ssh.UserCert: + // when no user policy engine is configured, but a host policy engine is + // configured, the user certificate is denied. + if a.sshUserPolicy == nil && a.sshHostPolicy != nil { + return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh user certificates"), "authority.SignSSH: error creating ssh user certificate") + } + if a.sshUserPolicy != nil { + allowed, err := a.sshUserPolicy.ArePrincipalsAllowed(certTpl) + if err != nil { + return nil, errs.InternalServerErr(err, + errs.WithMessage("authority.SignSSH: error creating ssh user certificate"), + ) + } + if !allowed { + return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh user certificate") + } + } + case ssh.HostCert: + // when no host policy engine is configured, but a user policy engine is + // configured, the host certificate is denied. + if a.sshHostPolicy == nil && a.sshUserPolicy != nil { + return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh host certificates"), "authority.SignSSH: error creating ssh user certificate") + } + if a.sshHostPolicy != nil { + allowed, err := a.sshHostPolicy.ArePrincipalsAllowed(certTpl) + if err != nil { + return nil, errs.InternalServerErr(err, + errs.WithMessage("authority.SignSSH: error creating ssh host certificate"), + ) + } + if !allowed { + return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh host certificate") + } + } + default: + return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) + } + // Sign certificate. cert, err := sshutil.CreateCertificate(certTpl, signer) if err != nil { diff --git a/authority/tls.go b/authority/tls.go index 58a1247c..d749e2ad 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -191,6 +191,25 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } + // If a policy is configured, perform allow/deny policy check on authority level + if a.x509Policy != nil { + allowed, err := a.x509Policy.AreCertificateNamesAllowed(leaf) + if err != nil { + return nil, errs.InternalServerErr(err, + errs.WithKeyVal("csr", csr), + errs.WithKeyVal("signOptions", signOpts), + errs.WithMessage("error creating certificate"), + ) + } + if !allowed { + // TODO: include SANs in error message? + return nil, errs.ApplyOptions( + errs.ForbiddenErr(errors.New("authority not allowed to sign"), "error creating certificate"), + opts..., + ) + } + } + // Sign certificate lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ From 3ec9a7310cf87bb0d576bc76f0ed6f65037cb8d2 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 8 Mar 2022 14:17:59 +0100 Subject: [PATCH 06/44] Fix ACME order identifier allow/deny check --- acme/api/order.go | 4 +++- acme/common.go | 6 +++--- authority/provisioner/acme.go | 18 +++++++++--------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/acme/api/order.go b/acme/api/order.go index 3d22ec0f..e1adebb3 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/randutil" ) @@ -107,7 +108,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { for _, identifier := range nor.Identifiers { // TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example? - err = prov.AuthorizeOrderIdentifier(ctx, identifier.Value) + orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} + err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier) if err != nil { api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return diff --git a/acme/common.go b/acme/common.go index 4b086dd7..9c5e732a 100644 --- a/acme/common.go +++ b/acme/common.go @@ -30,7 +30,7 @@ var clock Clock // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { - AuthorizeOrderIdentifier(ctx context.Context, identifier string) error + AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error GetID() string @@ -45,7 +45,7 @@ type MockProvisioner struct { Merr error MgetID func() string MgetName func() string - MauthorizeOrderIdentifier func(ctx context.Context, identifier string) error + MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeRevoke func(ctx context.Context, token string) error MdefaultTLSCertDuration func() time.Duration @@ -61,7 +61,7 @@ func (m *MockProvisioner) GetName() string { } // AuthorizeOrderIdentifiers mock -func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { +func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { if m.MauthorizeOrderIdentifier != nil { return m.MauthorizeOrderIdentifier(ctx, identifier) } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 2d5f74ff..9f8ef690 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -90,8 +90,6 @@ func (p *ACME) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - // TODO(hs): ensure no race conditions happen when reloading settings and requesting certs? - // TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny? if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } @@ -115,20 +113,22 @@ type ACMEIdentifier struct { Value string } -// AuthorizeOrderIdentifiers verifies the provisioner is authorized to issue a -// certificate for the Identifiers provided in an Order. -func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { +// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a +// certificate for an ACME Order Identifier. +func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error { + // identifier is allowed if no policy is configured if p.x509Policy == nil { return nil } // assuming only valid identifiers (IP or DNS) are provided var err error - if ip := net.ParseIP(identifier); ip != nil { - _, err = p.x509Policy.IsIPAllowed(ip) - } else { - _, err = p.x509Policy.IsDNSAllowed(identifier) + switch identifier.Type { + case IP: + _, err = p.x509Policy.IsIPAllowed(net.ParseIP(identifier.Value)) + case DNS: + _, err = p.x509Policy.IsDNSAllowed(identifier.Value) } return err From fd6a2eeb9cfb67427df94a93c367d9a044628590 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Mar 2022 18:39:09 -0800 Subject: [PATCH 07/44] Add provisioner controller The provisioner controller has the implementation of the identity function as well as the renew methods with renew after expiry support. --- authority/config/config.go | 25 ++-- authority/provisioner/claims.go | 47 ++++--- authority/provisioner/controller.go | 194 +++++++++++++++++++++++++++ authority/provisioner/provisioner.go | 87 +----------- 4 files changed, 246 insertions(+), 107 deletions(-) create mode 100644 authority/provisioner/controller.go diff --git a/authority/config/config.go b/authority/config/config.go index 589b5bbf..c33a2b1d 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -26,23 +26,26 @@ var ( DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false + // DefaultEnableRenewAfterExpiry enables renewals even when the certificate is expired. + DefaultEnableRenewAfterExpiry = false // DefaultEnableSSHCA enable SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // GlobalProvisionerClaims default claims for the Authority. Can be overridden // by provisioner specific claims. GlobalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &DefaultDisableRenewal, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &DefaultEnableSSHCA, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + DisableRenewal: &DefaultDisableRenewal, + EnableRenewAfterExpiry: &DefaultEnableRenewAfterExpiry, } ) diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 629a313c..c8bee2e5 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -10,10 +10,10 @@ import ( // Claims so that individual provisioners can override global claims. type Claims struct { // TLS CA properties - MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` - MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` - DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` - DisableRenewal *bool `json:"disableRenewal,omitempty"` + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` + // SSH CA properties MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` @@ -22,6 +22,10 @@ type Claims struct { MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"` + + // Renewal properties + DisableRenewal *bool `json:"disableRenewal,omitempty"` + EnableRenewAfterExpiry *bool `json:"enableRenewAfterExpiry,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() + enablerenewAfterExpiry := c.IsRenewAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() + return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - DisableRenewal: &disableRenewal, - MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, - MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, - DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, - MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, - MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, - DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, - EnableSSHCA: &enableSSHCA, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, + DisableRenewal: &disableRenewal, + EnableRenewAfterExpiry: &enablerenewAfterExpiry, } } @@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } +// IsRenewAfterExpiry returns if the renewal flow is authorized even if the +// certificate is expired. If the property is not set within the provisioner +// then the global value from the authority configuration will be used. +func (c *Claimer) IsRenewAfterExpiry() bool { + if c.claims == nil || c.claims.EnableRenewAfterExpiry == nil { + return *c.global.EnableRenewAfterExpiry + } + return *c.claims.EnableRenewAfterExpiry +} + // DefaultSSHCertDuration returns the default SSH certificate duration for the // given certificate type. func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go new file mode 100644 index 00000000..815482f9 --- /dev/null +++ b/authority/provisioner/controller.go @@ -0,0 +1,194 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "regexp" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" +) + +// Controller wraps a provisioner with other attributes useful in callback +// functions. +type Controller struct { + Interface + Audiences *Audiences + Claimer *Claimer + IdentityFunc GetIdentityFunc + AuthorizeRenewFunc AuthorizeRenewFunc + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc +} + +// NewController initializes a new provisioner controller. +func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { + claimer, err := NewClaimer(claims, config.Claims) + if err != nil { + return nil, err + } + return &Controller{ + Interface: p, + Audiences: &config.Audiences, + Claimer: claimer, + IdentityFunc: config.GetIdentityFunc, + AuthorizeRenewFunc: config.AuthorizeRenewFunc, + AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, + }, nil +} + +// GetIdentity returns the identity for a given email. +func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) { + if c.IdentityFunc != nil { + return c.IdentityFunc(ctx, c.Interface, email) + } + return DefaultIdentityFunc(ctx, c.Interface, email) +} + +// AuthorizeRenew returns nil if the given cert can be renewed, returns an error +// otherwise. +func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + if c.AuthorizeRenewFunc != nil { + return c.AuthorizeRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeRenew(ctx, c, cert) +} + +// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an +// error otherwise. +func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error { + if c.AuthorizeSSHRenewFunc != nil { + return c.AuthorizeSSHRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeSSHRenew(ctx, c, cert) +} + +// Identity is the type representing an externally supplied identity that is used +// by provisioners to populate certificate fields. +type Identity struct { + Usernames []string `json:"usernames"` + Permissions `json:"permissions"` +} + +// GetIdentityFunc is a function that returns an identity. +type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) + +// AuthorizeRenewFunc is a function that returns nil if the renewal of a +// certificate is enabled. +type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error + +// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the +// given SSH certificate is enabled. +type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error + +// DefaultIdentityFunc return a default identity depending on the provisioner +// type. For OIDC email is always present and the usernames might +// contain empty strings. +func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { + switch k := p.(type) { + case *OIDC: + // OIDC principals would be: + // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled + // 2. Sanitized local. + // 3. Raw local (if different). + // 4. Email address. + name := SanitizeSSHUserPrincipal(email) + if !sshUserRegex.MatchString(name) { + return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) + } + usernames := []string{name} + if i := strings.LastIndex(email, "@"); i >= 0 { + usernames = append(usernames, email[:i]) + } + usernames = append(usernames, email) + return &Identity{ + Usernames: SanitizeStringSlices(usernames), + }, nil + default: + return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) + } +} + +// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + now := time.Now().Truncate(time.Second) + if now.Before(cert.NotBefore) { + return errs.Unauthorized("certificate is not yet valid") + } + if now.After(cert.NotAfter) && !p.Claimer.IsRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + unixNow := time.Now().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return errs.Unauthorized("certificate is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.IsRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") + +// SanitizeStringSlices removes duplicated an empty strings. +func SanitizeStringSlices(original []string) []string { + output := []string{} + seen := make(map[string]struct{}) + for _, entry := range original { + if entry == "" { + continue + } + if _, value := seen[entry]; !value { + seen[entry] = struct{}{} + output = append(output, entry) + } + } + return output +} + +// SanitizeSSHUserPrincipal grabs an email or a string with the format +// local@domain and returns a sanitized version of the local, valid to be used +// as a user name. If the email starts with a letter between a and z, the +// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. +func SanitizeSSHUserPrincipal(email string) string { + if i := strings.LastIndex(email, "@"); i >= 0 { + email = email[:i] + } + return strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= '0' && r <= '9': + return r + case r == '-': + return '-' + case r == '.': // drop dots + return -1 + default: + return '_' + } + }, strings.ToLower(email)) +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 55ebe092..0b79bf4f 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -6,7 +6,6 @@ import ( "encoding/json" stderrors "errors" "net/url" - "regexp" "strings" "github.com/pkg/errors" @@ -210,6 +209,12 @@ type Config struct { // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc + // AuthorizeRenewFunc is a function that returns nil if a given X.509 + // certificate can be renewed. + AuthorizeRenewFunc AuthorizeRenewFunc + // AuthorizeSSHRenewFunc is a function that returns nil if a given SSH + // certificate can be renewed. + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc } type provisioner struct { @@ -278,32 +283,6 @@ func (l *List) UnmarshalJSON(data []byte) error { return nil } -var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") - -// SanitizeSSHUserPrincipal grabs an email or a string with the format -// local@domain and returns a sanitized version of the local, valid to be used -// as a user name. If the email starts with a letter between a and z, the -// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. -func SanitizeSSHUserPrincipal(email string) string { - if i := strings.LastIndex(email, "@"); i >= 0 { - email = email[:i] - } - return strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z': - return r - case r >= '0' && r <= '9': - return r - case r == '-': - return '-' - case r == '.': // drop dots - return -1 - default: - return '_' - } - }, strings.ToLower(email)) -} - type base struct{} // AuthorizeSign returns an unimplemented error. Provisioners should overwrite @@ -348,66 +327,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") } -// Identity is the type representing an externally supplied identity that is used -// by provisioners to populate certificate fields. -type Identity struct { - Usernames []string `json:"usernames"` - Permissions `json:"permissions"` -} - // Permissions defines extra extensions and critical options to grant to an SSH certificate. type Permissions struct { Extensions map[string]string `json:"extensions"` CriticalOptions map[string]string `json:"criticalOptions"` } -// GetIdentityFunc is a function that returns an identity. -type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) - -// DefaultIdentityFunc return a default identity depending on the provisioner -// type. For OIDC email is always present and the usernames might -// contain empty strings. -func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { - switch k := p.(type) { - case *OIDC: - // OIDC principals would be: - // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled - // 2. Sanitized local. - // 3. Raw local (if different). - // 4. Email address. - name := SanitizeSSHUserPrincipal(email) - if !sshUserRegex.MatchString(name) { - return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) - } - usernames := []string{name} - if i := strings.LastIndex(email, "@"); i >= 0 { - usernames = append(usernames, email[:i]) - } - usernames = append(usernames, email) - return &Identity{ - Usernames: SanitizeStringSlices(usernames), - }, nil - default: - return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) - } -} - -// SanitizeStringSlices removes duplicated an empty strings. -func SanitizeStringSlices(original []string) []string { - output := []string{} - seen := make(map[string]struct{}) - for _, entry := range original { - if entry == "" { - continue - } - if _, value := seen[entry]; !value { - seen[entry] = struct{}{} - output = append(output, entry) - } - } - return output -} - // MockProvisioner for testing type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} From 3c2ff33ca90cb97fd861b9a479eb851e4feffab0 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Mar 2022 18:43:27 -0800 Subject: [PATCH 08/44] Add provisioner controller tests. --- authority/provisioner/controller_test.go | 391 +++++++++++++++++++++++ 1 file changed, 391 insertions(+) create mode 100644 authority/provisioner/controller_test.go diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go new file mode 100644 index 00000000..68f7055c --- /dev/null +++ b/authority/provisioner/controller_test.go @@ -0,0 +1,391 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "fmt" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +var trueValue = true + +func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer { + t.Helper() + c, err := NewClaimer(claims, global) + if err != nil { + t.Fatal(err) + } + return c +} +func mustDuration(t *testing.T, s string) *Duration { + t.Helper() + d, err := NewDuration(s) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestNewController(t *testing.T) { + type args struct { + p Interface + claims *Claims + config Config + } + tests := []struct { + name string + args args + want *Controller + wantErr bool + }{ + {"ok", args{&JWK{}, nil, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, false}, + {"ok with claims", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + }, false}, + {"fail claimer", args{&JWK{}, &Claims{ + MinTLSDur: mustDuration(t, "24h"), + MaxTLSDur: mustDuration(t, "2h"), + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewController() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_GetIdentity(t *testing.T) { + ctx := context.Background() + type fields struct { + Interface Interface + IdentityFunc GetIdentityFunc + } + type args struct { + ctx context.Context + email string + } + tests := []struct { + name string + fields fields + args args + want *Identity + wantErr bool + }{ + {"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane", "jane@doe.org"}, + }, false}, + {"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return &Identity{Usernames: []string{"jane"}}, nil + }}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane"}, + }, false}, + {"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true}, + {"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return nil, fmt.Errorf("an error") + }}, args{ctx, "jane@doe.org"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + IdentityFunc: tt.fields.IdentityFunc, + } + got, err := c.GetIdentity(tt.args.ctx, tt.args.email) + if (err != nil) != tt.wantErr { + t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_AuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeRenewFunc AuthorizeRenewFunc + } + type args struct { + ctx context.Context + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc, + } + if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestController_AuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + } + type args struct { + ctx context.Context + cert *ssh.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc, + } + if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *x509.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *ssh.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} From 259e95947cb2300603098956dbd470875d477c2f Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Mar 2022 18:43:45 -0800 Subject: [PATCH 09/44] Add support for the provisioner controller The claimer, audiences and custom callback methods are now managed by the provisioner controller in an uniform way. --- authority/authorize.go | 2 +- authority/authorize_test.go | 10 +- authority/provisioner/acme.go | 22 +-- authority/provisioner/acme_test.go | 25 ++- authority/provisioner/aws.go | 30 ++- authority/provisioner/aws_test.go | 23 ++- authority/provisioner/azure.go | 32 ++-- authority/provisioner/azure_test.go | 23 ++- authority/provisioner/gcp.go | 39 ++-- authority/provisioner/gcp_test.go | 23 ++- authority/provisioner/jwk.go | 37 ++-- authority/provisioner/jwk_test.go | 33 ++-- authority/provisioner/k8sSA.go | 45 ++--- authority/provisioner/k8sSA_test.go | 31 +-- authority/provisioner/nebula.go | 52 +++-- authority/provisioner/nebula_test.go | 68 +++---- authority/provisioner/oidc.go | 34 +--- authority/provisioner/oidc_test.go | 29 +-- authority/provisioner/scep.go | 32 ++-- .../provisioner/sign_ssh_options_test.go | 20 +- authority/provisioner/sshpop.go | 29 ++- authority/provisioner/sshpop_test.go | 9 +- authority/provisioner/utils_test.go | 181 ++++++++---------- authority/provisioner/x5c.go | 55 +++--- authority/provisioner/x5c_test.go | 29 +-- authority/tls_test.go | 12 +- 26 files changed, 450 insertions(+), 475 deletions(-) diff --git a/authority/authorize.go b/authority/authorize.go index 5108f567..4f64921b 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -276,6 +276,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} + isRevoked, err := a.IsRevoked(serial) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) @@ -283,7 +284,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { if isRevoked { return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } - p, ok := a.provisioners.LoadByCertificate(cert) if !ok { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6d524a25..74f313e7 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -753,6 +753,7 @@ func TestAuthority_Authorize(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) { fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") + fooCrt.NotAfter = time.Now().Add(time.Hour) assert.FatalError(t, err) renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") @@ -822,7 +823,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { return &authorizeTest{ auth: a, cert: renewDisabledCrt, - err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"), code: http.StatusUnauthorized, } }, @@ -909,6 +910,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner. } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -917,6 +919,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 21958d36..913d0ace 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -6,7 +6,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -24,7 +23,7 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (p *ACME) DefaultTLSCertDuration() time.Duration { - return p.claimer.DefaultTLSCertDuration() + return p.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a JWK type. @@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign does not do any validation, because all validation is handled @@ -97,10 +92,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -118,8 +113,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bd173f87..86e8a9a9 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) { } func TestACME_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *ACME cert *x509.Certificate @@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateACME() assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -179,11 +186,11 @@ func TestACME_AuthorizeSign(t *testing.T) { case *forceCNOption: assert.Equals(t, v.ForceCN, tc.p.ForceCN) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index fdad7b4a..5f79d7d0 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -264,9 +264,8 @@ type AWS struct { IIDRoots string `json:"iidRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *awsConfig - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Add default config if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // validate IMDS versions if len(p.IMDSVersions) == 0 { @@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) { } } - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -473,11 +470,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -486,10 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized @@ -664,7 +658,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } // validate audiences with the defaults - if !matchesAudience(payload.Audience, p.audiences.Sign) { + if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") } @@ -704,7 +698,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -752,11 +746,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 0d2786db..2e684272 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -682,13 +682,13 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), tt.args.cn) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: @@ -726,7 +726,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") @@ -747,7 +747,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -824,6 +824,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } func TestAWS_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() @@ -832,7 +833,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -845,8 +846,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 384617e0..d9654566 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -96,10 +96,10 @@ type Azure struct { DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -203,27 +203,24 @@ func (p *Azure) Init(config Config) (err error) { case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint - if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { - return err + if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { + return } if err := p.oidcConfig.Validate(); err != nil { return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) } // Get JWK key set if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { - return err + return } - return nil + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken returns the claims, name, group, subscription, identityObjectID, error. @@ -355,10 +352,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -367,15 +364,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } @@ -420,11 +414,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 4ab734d5..c40d0f93 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -511,13 +511,13 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "virtualMachine") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -536,6 +536,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { } func TestAzure_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() @@ -544,7 +545,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -557,8 +558,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -595,7 +602,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("subject", "caURL") @@ -616,7 +623,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"virtualMachine"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e46f4ce4..6070b640 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -88,10 +88,9 @@ type GCP struct { InstanceAge Duration `json:"instanceAge,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *gcpConfig keyStore *keyStore - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name should uniquely @@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { } // Init validates and initializes the GCP provisioner. -func (p *GCP) Init(config Config) error { - var err error +func (p *GCP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Initialize key store - p.keyStore, err = newKeyStore(p.config.CertsURL) - if err != nil { - return err + if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil { + return } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -269,19 +265,16 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized. @@ -328,7 +321,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } // validate audiences with the defaults - if !matchesAudience(claims.Audience, p.audiences.Sign) { + if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } @@ -383,7 +376,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -431,11 +424,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 5f6f9bc7..2fc7fee0 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -554,13 +554,13 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -595,7 +595,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], @@ -622,7 +622,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -698,6 +698,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } func TestGCP_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() @@ -706,7 +707,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -719,8 +720,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renewal-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 137915c8..764f5d7d 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -35,8 +35,9 @@ type JWK struct { EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + // claimer *Claimer + // audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -98,13 +99,8 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -146,13 +142,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } @@ -179,12 +175,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -193,18 +189,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } @@ -261,11 +254,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil @@ -273,6 +266,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.SSHRevoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index deae8f7a..f6b2d93c 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) { }, "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, } }, } @@ -305,13 +305,13 @@ func TestJWK_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "subject") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: @@ -325,6 +325,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } func TestJWK_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() @@ -333,7 +334,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -346,8 +347,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -373,7 +380,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p2.Claims = &Claims{EnableSSHCA: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) @@ -402,8 +409,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -485,8 +492,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { signer, err := generateJSONWebKey() assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index d260f5ec..557d571a 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -42,16 +42,15 @@ type k8sSAPayload struct { // entity trusted to make signature requests. type K8sSA struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - PubKeys []byte `json:"publicKeys,omitempty"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + PubKeys []byte `json:"publicKeys,omitempty"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } @@ -240,27 +234,24 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") } @@ -282,11 +273,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Require type, key-id and principals in the SignSSHOptions. &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 176cdfd3..2f357ebe 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { } func TestK8sSA_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *K8sSA cert *x509.Certificate @@ -192,21 +193,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateK8sSA(nil) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -281,11 +288,11 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p.Claims = &Claims{EnableSSHCA: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertOptionsRequireValidator: assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshDefaultDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 72a275ff..11cff219 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -34,19 +34,18 @@ const ( // https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by // go.step.sm/crypto/x25519. type Nebula struct { - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - caPool *nebula.NebulaCAPool - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + caPool *nebula.NebulaCAPool + ctl *Controller } // Init verifies and initializes the Nebula provisioner. -func (p *Nebula) Init(config Config) error { +func (p *Nebula) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error { return errors.New("provisioner root(s) cannot be empty") } - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) if err != nil { return errs.InternalServer("failed to create ca pool: %v", err) } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // GetID returns the provisioner id. @@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) { // AuthorizeSign returns the list of SignOption for a Sign request. func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - crt, claims, err := p.authorizeToken(token, p.audiences.Sign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, err } @@ -154,7 +148,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // modifiers / withOptions newProvisionerExtensionOption(TypeNebula, p.Name, ""), profileLimitDuration{ - def: p.claimer.DefaultTLSCertDuration(), + def: p.ctl.Claimer.DefaultTLSCertDuration(), notBefore: crt.Details.NotBefore, notAfter: crt.Details.NotAfter, }, @@ -165,18 +159,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, IPs: crt.Details.Ips, }, defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // Currently the Nebula provisioner only grants host SSH certificates. func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, err } @@ -254,11 +248,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti return append(signOptions, templateOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, crt.Details.NotAfter}, + &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil @@ -266,7 +260,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti // AuthorizeRenew returns an error if the renewal is disabled. func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { + if p.ctl.Claimer.IsDisableRenewal() { return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName()) } return nil @@ -274,15 +268,15 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro // AuthorizeRevoke returns an error if the token is not valid. func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { - return p.validateToken(token, p.audiences.Revoke) + return p.validateToken(token, p.ctl.Audiences.Revoke) } // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { + if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil { return err } return nil diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index bc539af1..8f9afd9d 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) - t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) + t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) _, claims, err := parseToken(t1) if err != nil { t.Fatal(err) @@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) { ctx := context.TODO() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool @@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1"}, }, crt, priv) - okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) - okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv) + okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), }, crt, priv) - failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "user", }, crt, priv) - failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, @@ -584,12 +584,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -618,12 +618,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Provisioner with SSH disabled var bFalse bool @@ -657,7 +657,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -689,7 +689,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -726,20 +726,20 @@ func TestNebula_authorizeToken(t *testing.T) { t1 := now() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) - okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv) + okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv) - okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) + okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv) // Token with errors - failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) - failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) + failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) - failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) // Not a nebula token jwk, err := generateJSONWebKey() @@ -761,7 +761,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.Sign[0]}, + Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", @@ -770,7 +770,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.SSHSign[0]}, + Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { @@ -785,14 +785,14 @@ func TestNebula_authorizeToken(t *testing.T) { want1 *jwtPayload wantErr bool }{ - {"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, - {"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, }, false}, - {"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ @@ -802,16 +802,16 @@ func TestNebula_authorizeToken(t *testing.T) { }, }, }, false}, - {"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, }, false}, - {"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, - {"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, - {"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, - {"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, - {"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, - {"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, - {"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, + {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index ac1f2a25..1fc9bb4b 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -92,8 +92,7 @@ type OIDC struct { Options *Options `json:"options,omitempty"` configuration openIDConfiguration keyStore *keyStore - claimer *Claimer - getIdentityFunc GetIdentityFunc + ctl *Controller } func sanitizeEmail(email string) string { @@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) { } } - // Update claims with global ones - if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint u, err := url.Parse(o.ConfigurationEndpoint) if err != nil { @@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) { return err } - // Set the identity getter if it exists, otherwise use the default. - if config.GetIdentityFunc == nil { - o.getIdentityFunc = DefaultIdentityFunc - } else { - o.getIdentityFunc = config.GetIdentityFunc - } - return nil + o.ctl, err = NewController(o, o.Claims, config) + return } // ValidatePayload validates the given token payload. @@ -359,10 +348,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), - profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -371,15 +360,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if o.claimer.IsDisableRenewal() { - return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName()) - } - return nil + return o.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !o.claimer.IsSSHCAEnabled() { + if !o.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) @@ -394,7 +380,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. Note that the PreferredUsername might be empty. // TBD: Would preferred_username present a safety issue here? - iden, err := o.getIdentityFunc(ctx, o, claims.Email) + iden, err := o.ctl.GetIdentity(ctx, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } @@ -445,11 +431,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{o.claimer}, + &sshDefaultDuration{o.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{o.claimer}, + &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 7bf6ad7a..cfc789f9 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -332,11 +332,11 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: @@ -411,6 +411,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { } func TestOIDC_AuthorizeRenew(t *testing.T) { + now := time.Now() p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() @@ -419,7 +420,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -432,8 +433,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -478,7 +485,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p6.Claims = &Claims{EnableSSHCA: &disable} - p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) + p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) assert.FatalError(t, err) // Update configuration endpoints and initialize @@ -494,10 +501,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } // Additional test needed for empty usernames and duplicate email and usernames @@ -527,8 +534,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 5d67762c..f4cffd78 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -11,28 +11,30 @@ import ( // SCEP provisioning flow type SCEP struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` ChallengePassword string `json:"challenge,omitempty"` Capabilities []string `json:"capabilities,omitempty"` + // IncludeRoot makes the provisioner return the CA root in addition to the // intermediate in the GetCACerts response IncludeRoot bool `json:"includeRoot,omitempty"` + // MinimumPublicKeyLength is the minimum length for public keys in CSRs MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` secretChallengePassword string encryptionAlgorithm int + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (s *SCEP) DefaultTLSCertDuration() time.Duration { - return s.claimer.DefaultTLSCertDuration() + return s.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a SCEP type. @@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil { - return err - } - // Mask the actual challenge value, so it won't be marshaled s.secretChallengePassword = s.ChallengePassword s.ChallengePassword = "*** redacted ***" @@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - return err + s.ctl, err = NewController(s, s.Claims, config) + return } // AuthorizeSign does not do any verification, because all verification is handled @@ -126,10 +124,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newForceCNOption(s.ForceCN), - profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), - newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), }, nil } diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index b59d6945..28a35639 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) - v := sshCertValidityValidator{p.claimer} + v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string @@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) { tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), @@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/type-not-recognized": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), @@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validAfter-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), @@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validBefore-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), @@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3039d2a3..a7df38de 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -29,8 +29,7 @@ type SSHPOP struct { Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller sshPubKeys *SSHKeys } @@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a SSHPOP type. -func (p *SSHPOP) Init(config Config) error { +func (p *SSHPOP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error { return errors.New("provisioner public SSH validation keys cannot be empty") } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys - return nil + + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -186,7 +181,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // AuthorizeSSHRevoke validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { - claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } @@ -199,22 +194,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { // AuthorizeSSHRenew validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRenew) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } - - return claims.sshCert, nil - + return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert) } // AuthorizeSSHRekey validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRekey) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } @@ -225,7 +218,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index da036864..715bf6de 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) { } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..ff8421f0 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,20 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, + defaultDisableRenewal = false + defaultEnableRenewAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + EnableRenewAfterExpiry: &defaultEnableRenewAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, @@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &JWK{ + + p := &JWK{ Name: name, Type: "JWK", Key: &public, EncryptedKey: encrypted, Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { @@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } pubKeys := []interface{}{fooPub, barPub} if inputPubKey != nil { pubKeys = append(pubKeys, inputPubKey) } - return &K8sSA{ - Name: K8sSAName, - Type: "K8sSA", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - pubKeys: pubKeys, - }, nil + p := &K8sSA{ + Name: K8sSAName, + Type: "K8sSA", + Claims: &globalProvisionerClaims, + pubKeys: pubKeys, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateSSHPOP() (*SSHPOP, error) { @@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") if err != nil { return nil, err @@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) { return nil, err } - return &SSHPOP{ - Name: name, - Type: "SSHPOP", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, + p := &SSHPOP{ + Name: name, + Type: "SSHPOP", + Claims: &globalProvisionerClaims, sshPubKeys: &SSHKeys{ UserKeys: []ssh.PublicKey{userKey}, HostKeys: []ssh.PublicKey{hostKey}, }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateX5C(root []byte) (*X5C, error) { @@ -283,11 +279,6 @@ M46l92gdOozT if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - rootPool := x509.NewCertPool() var ( @@ -305,15 +296,17 @@ M46l92gdOozT } rootPool.AddCert(cert) } - return &X5C{ - Name: name, - Type: "X5C", - Roots: root, - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - rootPool: rootPool, - }, nil + p := &X5C{ + Name: name, + Type: "X5C", + Roots: root, + Claims: &globalProvisionerClaims, + rootPool: rootPool, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateOIDC() (*OIDC, error) { @@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &OIDC{ + p := &OIDC{ Name: name, Type: "OIDC", ClientID: clientID, @@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateGCP() (*GCP, error) { @@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &GCP{ + p := &GCP{ Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, Claims: &globalProvisionerClaims, - claimer: claimer, config: newGCPConfig(), keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - audiences: testAudiences.WithFragment("gcp/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("gcp/" + name), + }) + return p, err } func generateAWS() (*AWS, error) { @@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v2", "v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServer() (*AWS, *httptest.Server, error) { @@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { @@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } jwk, err := generateJSONWebKey() if err != nil { return nil, err } - return &Azure{ + p := &Azure{ Type: "Azure", Name: name, TenantID: tenantID, Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, - claimer: claimer, config: newAzureConfig(tenantID), oidcConfig: openIDConfiguration{ Issuer: "https://sts.windows.net/" + tenantID + "/", @@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateAzureWithServer() (*Azure, *httptest.Server, error) { diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index aa44245d..6f534c76 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -26,15 +26,14 @@ type x5cPayload struct { // signature requests. type X5C struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences - rootPool *x509.CertPool + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + ctl *Controller + rootPool *x509.CertPool } // GetID returns the provisioner unique identifier. The name and credential id @@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a X5C type. -func (p *X5C) Init(config Config) error { +func (p *X5C) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -119,14 +118,9 @@ func (p *X5C) Init(config Config) error { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -189,13 +183,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } @@ -227,31 +221,30 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, ""), - profileLimitDuration{p.claimer.DefaultTLSCertDuration(), - claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, + profileLimitDuration{ + p.ctl.Claimer.DefaultTLSCertDuration(), + claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter, + }, // validators commonNameValidator(claims.Subject), defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") } @@ -314,11 +307,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, + &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 2959f8c6..330e6e7a 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,6 +2,7 @@ package provisioner import ( "context" + "crypto/x509" "net/http" "testing" "time" @@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) { }, "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, + p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, @@ -141,7 +142,7 @@ M46l92gdOozT } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) + assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } @@ -473,9 +474,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: - assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) + claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) case commonNameValidator: @@ -484,8 +485,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tc.sans) case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -551,6 +552,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { + now := time.Now() type test struct { p *X5C code int @@ -563,12 +565,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -582,7 +584,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { + if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -618,7 +623,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -774,10 +779,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { case sshCertDefaultsModifier: assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) diff --git a/authority/tls_test.go b/authority/tls_test.go index aeadaf0f..07538701 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -798,7 +798,7 @@ func TestAuthority_Renew(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -856,7 +856,7 @@ func TestAuthority_Renew(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), @@ -956,7 +956,7 @@ func TestAuthority_Rekey(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -998,7 +998,7 @@ func TestAuthority_Rekey(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1063,7 +1063,7 @@ func TestAuthority_Rekey(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), From afb5d362061cc327309d4e1af8b5647ffc173711 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Mar 2022 20:37:41 -0800 Subject: [PATCH 10/44] Allow to renew certificates using an x5c-like token. --- api/api.go | 2 +- api/api_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++---- api/renew.go | 53 ++++++++++++++++++++++++++-- 3 files changed, 138 insertions(+), 11 deletions(-) diff --git a/api/api.go b/api/api.go index 16e24bb2..c61e447f 100644 --- a/api/api.go +++ b/api/api.go @@ -43,7 +43,7 @@ type Authority interface { GetProvisioners(cursor string, limit int) (provisioner.List, string, error) Revoke(context.Context, *authority.RevokeOptions) error GetEncryptedKey(kid string) (string, error) - GetRoots() (federation []*x509.Certificate, err error) + GetRoots() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version } diff --git a/api/api_test.go b/api/api_test.go index c7528f9b..f2184596 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "encoding/pem" "fmt" @@ -34,6 +35,7 @@ import ( "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" + "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -920,32 +922,104 @@ func Test_caHandler_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } + + // Prepare root and leaf for renew after expiry test. + now := time.Now() + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + root := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + PublicKey: rootPub, + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + NotBefore: now.Add(-2 * time.Hour), + NotAfter: now.Add(time.Hour), + } + root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv) + if err != nil { + t.Fatal(err) + } + expiredLeaf := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Leaf certificate"}, + PublicKey: leafPub, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + EmailAddresses: []string{"test@example.org"}, + } + expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv) + if err != nil { + t.Fatal(err) + } + + // Generate renew after expiry token + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)}) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so) + if err != nil { + t.Fatal(err) + } + generateX5cToken := func(claims jose.Claims) string { + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s + } + tests := []struct { name string tls *tls.ConnectionState + header http.Header cert *x509.Certificate root *x509.Certificate err error statusCode int }{ - {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"ok renew after expiry", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + })}, + }, expiredLeaf, root, nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"fail expired token", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized}, + {"fail invalid root", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized}, } - expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, + getRoots: func() ([]*x509.Certificate, error) { + return []*x509.Certificate{tt.root}, nil + }, getTLSOptions: func() *authority.TLSOptions { return nil }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls + req.Header = tt.header w := httptest.NewRecorder() h.Renew(logging.NewResponseLogger(w), req) res := w.Result() @@ -960,8 +1034,14 @@ func Test_caHandler_Renew(t *testing.T) { t.Errorf("caHandler.Renew unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { + expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` + + `"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` + + `"certChain":["` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`) + if !bytes.Equal(bytes.TrimSpace(body), expected) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected) } } }) diff --git a/api/renew.go b/api/renew.go index 725322ee..a7449ba1 100644 --- a/api/renew.go +++ b/api/renew.go @@ -1,20 +1,30 @@ package api import ( + "crypto/x509" "net/http" + "strings" + "time" "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" +) + +const ( + authorizationHeader = "Authorization" + bearerScheme = "Bearer" ) // Renew uses the information of certificate in the TLS connection to create a // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + cert, err := h.getPeerCertificate(r) + if err != nil { + WriteError(w, err) return } - certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) + certChain, err := h.Authority.Renew(cert) if err != nil { WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -33,3 +43,40 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { TLSOptions: h.Authority.GetTLSOptions(), }, http.StatusCreated) } + +func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + return r.TLS.PeerCertificates[0], nil + } + + if s := r.Header.Get(authorizationHeader); s != "" { + if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { + roots, err := h.Authority.GetRoots() + if err != nil { + return nil, errs.BadRequestErr(err, "missing client certificate") + } + jwt, chain, err := jose.ParseX5cInsecure(parts[1], roots) + if err != nil { + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating client certificate")) + } + + var claims jose.Claims + leaf := chain[0][0] + if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating client certificate")) + } + + // According to "rfc7519 JSON Web Token" acceptable skew should be no + // more than a few minutes. + if err = claims.ValidateWithLeeway(jose.Expected{ + Time: time.Now().UTC(), + }, time.Minute); err != nil { + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating client certificate")) + } + + return leaf, nil + } + } + + return nil, errs.BadRequest("missing client certificate") +} From 8ef8f4f665248476b33fa600422713082c346fce Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Mar 2022 10:45:12 -0800 Subject: [PATCH 11/44] Use the provisioner controller in Nebula renewals --- authority/provisioner/nebula.go | 5 +---- authority/provisioner/nebula_test.go | 12 ++++++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 11cff219..1a6eee3e 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -260,10 +260,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti // AuthorizeRenew returns an error if the renewal is disabled. func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { - if p.ctl.Claimer.IsDisableRenewal() { - return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, crt) } // AuthorizeRevoke returns an error if the token is not valid. diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index 8f9afd9d..b190d607 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) { ctx := context.TODO() + now := time.Now().Truncate(time.Second) + // Ok provisioner p, _, _ := mustNebulaProvisioner(t) @@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) { args args wantErr bool }{ - {"ok", p, args{ctx, &x509.Certificate{}}, false}, - {"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, + {"ok", p, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"fail disabled", pDisabled, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 389815642d709b549cc443c9362c9fec3f3cf29c Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Mar 2022 10:46:28 -0800 Subject: [PATCH 12/44] Fix tests: certs are truncated to seconds. --- authority/provisioner/acme_test.go | 2 +- authority/provisioner/aws_test.go | 2 +- authority/provisioner/azure_test.go | 2 +- authority/provisioner/controller.go | 2 +- authority/provisioner/controller_test.go | 4 ++-- authority/provisioner/gcp_test.go | 2 +- authority/provisioner/jwk_test.go | 2 +- authority/provisioner/k8sSA_test.go | 2 +- authority/provisioner/oidc_test.go | 2 +- authority/provisioner/x5c_test.go | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 86e8a9a9..a74ef76e 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -91,7 +91,7 @@ func TestACME_Init(t *testing.T) { } func TestACME_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) type test struct { p *ACME cert *x509.Certificate diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 2e684272..3c6f8741 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -824,7 +824,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } func TestAWS_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index c40d0f93..da342ea4 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -536,7 +536,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { } func TestAzure_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 815482f9..97ebe8f8 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -122,7 +122,7 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif now := time.Now().Truncate(time.Second) if now.Before(cert.NotBefore) { - return errs.Unauthorized("certificate is not yet valid") + return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) } if now.After(cert.NotAfter) && !p.Claimer.IsRenewAfterExpiry() { return errs.Unauthorized("certificate has expired") diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index 68f7055c..bbf7cb81 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -134,7 +134,7 @@ func TestController_GetIdentity(t *testing.T) { func TestController_AuthorizeRenew(t *testing.T) { ctx := context.Background() - now := time.Now() + now := time.Now().Truncate(time.Second) type fields struct { Interface Interface Claimer *Claimer @@ -276,7 +276,7 @@ func TestController_AuthorizeSSHRenew(t *testing.T) { func TestDefaultAuthorizeRenew(t *testing.T) { ctx := context.Background() - now := time.Now() + now := time.Now().Truncate(time.Second) type args struct { ctx context.Context p *Controller diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 2fc7fee0..94fbd576 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -698,7 +698,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } func TestGCP_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index f6b2d93c..bf5c3d2c 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -325,7 +325,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } func TestJWK_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 2f357ebe..0a82e8ef 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -179,7 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { } func TestK8sSA_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) type test struct { p *K8sSA cert *x509.Certificate diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index cfc789f9..62082fb2 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -411,7 +411,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { } func TestOIDC_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 330e6e7a..18a31b04 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -552,7 +552,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { - now := time.Now() + now := time.Now().Truncate(time.Second) type test struct { p *X5C code int From 79349b4d7c051e30db03af3d2a6e4a51356df7fa Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Mar 2022 13:01:08 -0800 Subject: [PATCH 13/44] Add options to use custom renewal methods. --- authority/authority.go | 10 +++--- authority/authorize_test.go | 62 +++++++++++++++++++++++++------------ authority/options.go | 18 +++++++++++ authority/provisioners.go | 4 ++- authority/tls_test.go | 24 ++++++++++++++ 5 files changed, 93 insertions(+), 25 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index f396c588..cc26635e 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -70,10 +70,12 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) - sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) - getIdentityFunc provisioner.GetIdentityFunc + sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) + sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) + getIdentityFunc provisioner.GetIdentityFunc + authorizeRenewFunc provisioner.AuthorizeRenewFunc + authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc adminMutex sync.RWMutex } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 74f313e7..b0ab04ec 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1011,6 +1011,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } func TestAuthority_authorizeSSHRenew(t *testing.T) { + now := time.Now().UTC() + sshpop := func(a *Authority) (*ssh.Certificate, string) { + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return cert, token + } + a := testAuthority(t) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) @@ -1020,8 +1037,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) - now := time.Now().UTC() - validIssuer := "step-cli" type authorizeTest struct { @@ -1058,27 +1073,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { code: http.StatusUnauthorized, } }, + "fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return errs.Forbidden("forbidden") + })) + _, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, + err: errors.New("authority.authorizeSSHRenew: forbidden"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *authorizeTest { - key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") - assert.FatalError(t, err) - signer, ok := key.(crypto.Signer) - assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") - sshSigner, err := ssh.NewSignerFromSigner(signer) - assert.FatalError(t, err) - - cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) - assert.FatalError(t, err) - - p, ok := a.provisioners.Load("sshpop/sshpop") - assert.Fatal(t, ok, "sshpop provisioner not found in test authority") - - tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", - []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) - assert.FatalError(t, err) - + cert, token := sshpop(a) return &authorizeTest{ auth: a, - token: tok, + token: token, + cert: cert, + } + }, + "ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return nil + })) + cert, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, cert: cert, } }, diff --git a/authority/options.go b/authority/options.go index f92db99b..a1238b1d 100644 --- a/authority/options.go +++ b/authority/options.go @@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e } } +// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of +// an X.509 certificate. +func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeRenewFunc = fn + return nil + } +} + +// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal +// of a SSH certificate. +func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeSSHRenewFunc = fn + return nil + } +} + // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { diff --git a/authority/provisioners.go b/authority/provisioners.go index 8dc27c6a..780d12c0 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, }, - GetIdentityFunc: a.getIdentityFunc, + GetIdentityFunc: a.getIdentityFunc, + AuthorizeRenewFunc: a.authorizeRenewFunc, + AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc, }, nil } diff --git a/authority/tls_test.go b/authority/tls_test.go index 07538701..6ccf02ca 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -802,6 +802,19 @@ func TestAuthority_Renew(t *testing.T) { code: http.StatusUnauthorized, }, nil }, + "fail/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return errs.Unauthorized("not authorized") + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), + code: http.StatusUnauthorized, + }, nil + }, "ok": func() (*renewTest, error) { return &renewTest{ auth: a, @@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) { cert: cert, }, nil }, + "ok/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return nil + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + }, nil + }, } for name, genTestCase := range tests { From 41ea67ce1091400aa4c1a892e16de0e13c8b16f6 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Mar 2022 13:01:31 -0800 Subject: [PATCH 14/44] Attempt to fix a bootstrap tests --- ca/bootstrap_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 9482d657..0e16bd7d 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -408,6 +408,7 @@ func TestBootstrapClientServerRotation(t *testing.T) { server.ServeTLS(listener, "", "") }() defer server.Close() + time.Sleep(1 * time.Second) // Create bootstrap client token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") @@ -419,7 +420,6 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { - time.Sleep(1 * time.Second) // test with ca resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { From 616490a9c6b8fc2b914c000c6ac3a7e714bbd939 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Mar 2022 20:21:01 -0800 Subject: [PATCH 15/44] Refactor renew after expiry token authorization This changes adds a new authority method that authorizes the renew after expiry tokens. --- api/api.go | 1 + api/api_test.go | 37 +++- api/renew.go | 29 +-- authority/authorize.go | 78 ++++++++ authority/authorize_test.go | 285 +++++++++++++++++++++++++++ authority/config/config.go | 36 ++-- authority/provisioner/jwk.go | 4 +- authority/provisioner/provisioner.go | 10 + ca/client.go | 30 +++ ca/client_test.go | 68 +++++++ go.mod | 2 +- go.sum | 4 +- 12 files changed, 527 insertions(+), 57 deletions(-) diff --git a/api/api.go b/api/api.go index c61e447f..7786bd0d 100644 --- a/api/api.go +++ b/api/api.go @@ -33,6 +33,7 @@ type Authority interface { // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) + AuthorizeRenewToken(ctx context.Context, token string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index f2184596..717621cd 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -173,6 +173,7 @@ type mockAuthority struct { ret1, ret2 interface{} err error authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -210,6 +211,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err return m.ret1.([]provisioner.SignOption), m.err } +func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + if m.authorizeRenewToken != nil { + return m.authorizeRenewToken(ctx, ott) + } + return m.ret1.(*x509.Certificate), m.err +} + func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { if m.getTLSOptions != nil { return m.getTLSOptions() @@ -1010,8 +1018,21 @@ func Test_caHandler_Renew(t *testing.T) { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, - getRoots: func() ([]*x509.Certificate, error) { - return []*x509.Certificate{tt.root}, nil + authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { + jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) + if err != nil { + return nil, errs.Unauthorized(err.Error()) + } + var claims jose.Claims + if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Time: now, + }, time.Minute); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + return chain[0][0], nil }, getTLSOptions: func() *authority.TLSOptions { return nil @@ -1022,17 +1043,19 @@ func Test_caHandler_Renew(t *testing.T) { req.Header = tt.header w := httptest.NewRecorder() h.Renew(logging.NewResponseLogger(w), req) - res := w.Result() - if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) - } + res := w.Result() + defer res.Body.Close() body, err := io.ReadAll(res.Body) - res.Body.Close() if err != nil { t.Errorf("caHandler.Renew unexpected error = %v", err) } + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("%s", body) + } + if tt.statusCode < http.StatusBadRequest { expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` + `"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` + diff --git a/api/renew.go b/api/renew.go index a7449ba1..408d91a3 100644 --- a/api/renew.go +++ b/api/renew.go @@ -4,10 +4,8 @@ import ( "crypto/x509" "net/http" "strings" - "time" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" ) const ( @@ -48,35 +46,10 @@ func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, erro if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], nil } - if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - roots, err := h.Authority.GetRoots() - if err != nil { - return nil, errs.BadRequestErr(err, "missing client certificate") - } - jwt, chain, err := jose.ParseX5cInsecure(parts[1], roots) - if err != nil { - return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating client certificate")) - } - - var claims jose.Claims - leaf := chain[0][0] - if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { - return nil, errs.InternalServerErr(err, errs.WithMessage("error validating client certificate")) - } - - // According to "rfc7519 JSON Web Token" acceptable skew should be no - // more than a few minutes. - if err = claims.ValidateWithLeeway(jose.Expected{ - Time: time.Now().UTC(), - }, time.Minute); err != nil { - return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating client certificate")) - } - - return leaf, nil + return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) } } - return nil, errs.BadRequest("missing client certificate") } diff --git a/authority/authorize.go b/authority/authorize.go index 4f64921b..7c1c2ff6 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/hex" "net/http" + "net/url" "strconv" "strings" "time" @@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error } return nil } + +// AuthorizeRenewToken validates the renew token and returns the leaf +// certificate in the x5cInsecure header. +func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + var claims jose.Claims + jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) + if err != nil { + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + leaf := chain[0][0] + if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) + } + + p, ok := a.provisioners.LoadByCertificate(leaf) + if !ok { + return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") + } + if err := a.UseToken(ott, p); err != nil { + return nil, err + } + + if err := claims.ValidateWithLeeway(jose.Expected{ + Issuer: p.GetName(), + Subject: leaf.Subject.CommonName, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + switch err { + case jose.ErrInvalidIssuer: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)")) + case jose.ErrInvalidSubject: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)")) + case jose.ErrNotValidYet: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)")) + case jose.ErrExpired: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)")) + case jose.ErrIssuedInTheFuture: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)")) + default: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + } + + audiences := a.config.GetAudiences().Renew + if !matchesAudience(claims.Audience, audiences) { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) + } + + return leaf, nil +} + +// matchesAudience returns true if A and B share at least one element. +func matchesAudience(as, bs []string) bool { + if len(bs) == 0 || len(as) == 0 { + return false + } + + for _, b := range bs { + for _, a := range as { + if b == a || stripPort(a) == stripPort(b) { + return true + } + } + } + return false +} + +// stripPort attempts to strip the port from the given url. If parsing the url +// produces errors it will just return the passed argument. +func stripPort(rawurl string) string { + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + u.Host = u.Hostname() + return u.String() +} diff --git a/authority/authorize_test.go b/authority/authorize_test.go index b0ab04ec..b631741a 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -3,11 +3,15 @@ package authority import ( "context" "crypto" + "crypto/ed25519" "crypto/rand" "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" "encoding/base64" "fmt" "net/http" + "reflect" "strconv" "testing" "time" @@ -20,6 +24,7 @@ import ( "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -1320,3 +1325,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { }) } } + +func TestAuthority_AuthorizeRenewToken(t *testing.T) { + ctx := context.Background() + type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` + } + + _, signer, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer) + if err != nil { + t.Fatal(err) + } + _, otherSigner, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { + chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + if err != nil { + t.Fatal(err) + } + + var x5c []string + for _, c := range chain { + x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw)) + } + + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", x5c) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so) + if err != nil { + t.Fatal(err) + } + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s, chain[0] + } + + now := time.Now() + a1 := testAuthority(t) + t1, c1 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + t2, c2 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now.Add(-time.Hour) + cert.NotAfter = now.Add(-time.Minute) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "bad-issuer", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "bad-subject", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + + type args struct { + ctx context.Context + ott string + } + tests := []struct { + name string + authority *Authority + args args + want *x509.Certificate + wantErr bool + }{ + {"ok", a1, args{ctx, t1}, c1, false}, + {"ok expired cert", a1, args{ctx, t2}, c2, false}, + {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, + {"fail token reuse", a1, args{ctx, t1}, nil, true}, + {"fail token signature", a1, args{ctx, badSigner}, nil, true}, + {"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true}, + {"fail token iss", a1, args{ctx, badIssuer}, nil, true}, + {"fail token sub", a1, args{ctx, badSubject}, nil, true}, + {"fail token iat", a1, args{ctx, badNotBefore}, nil, true}, + {"fail token iat", a1, args{ctx, badExpiry}, nil, true}, + {"fail token iat", a1, args{ctx, badIssuedAt}, nil, true}, + {"fail token aud", a1, args{ctx, badAudience}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/config/config.go b/authority/config/config.go index c33a2b1d..e4fcc863 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -272,28 +272,32 @@ func (c *Config) GetAudiences() provisioner.Audiences { } for _, name := range c.DNSNames { + hostname := toHostname(name) audiences.Sign = append(audiences.Sign, - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname), + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname)) + audiences.Renew = append(audiences.Renew, + fmt.Sprintf("https://%s/1.0/renew", hostname), + fmt.Sprintf("https://%s/renew", hostname)) audiences.Revoke = append(audiences.Revoke, - fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), - fmt.Sprintf("https://%s/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/revoke", hostname), + fmt.Sprintf("https://%s/revoke", hostname)) audiences.SSHSign = append(audiences.SSHSign, - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname), + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname)) audiences.SSHRevoke = append(audiences.SSHRevoke, - fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), - fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname), + fmt.Sprintf("https://%s/ssh/revoke", hostname)) audiences.SSHRenew = append(audiences.SSHRenew, - fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), - fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/renew", hostname), + fmt.Sprintf("https://%s/ssh/renew", hostname)) audiences.SSHRekey = append(audiences.SSHRekey, - fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), - fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname), + fmt.Sprintf("https://%s/ssh/rekey", hostname)) } return audiences diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 764f5d7d..c014bec0 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -35,9 +35,7 @@ type JWK struct { EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - // claimer *Claimer - // audiences Audiences - ctl *Controller + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 0b79bf4f..7438ea17 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -46,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse") // Audiences stores all supported audiences by request type. type Audiences struct { Sign []string + Renew []string Revoke []string SSHSign []string SSHRevoke []string @@ -56,6 +57,7 @@ type Audiences struct { // All returns all supported audiences across all request types in one list. func (a Audiences) All() (auds []string) { auds = a.Sign + auds = append(auds, a.Renew...) auds = append(auds, a.Revoke...) auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHRevoke...) @@ -69,6 +71,7 @@ func (a Audiences) All() (auds []string) { func (a Audiences) WithFragment(fragment string) Audiences { ret := Audiences{ Sign: make([]string, len(a.Sign)), + Renew: make([]string, len(a.Renew)), Revoke: make([]string, len(a.Revoke)), SSHSign: make([]string, len(a.SSHSign)), SSHRevoke: make([]string, len(a.SSHRevoke)), @@ -82,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences { ret.Sign[i] = s } } + for i, s := range a.Renew { + if u, err := url.Parse(s); err == nil { + ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() + } else { + ret.Renew[i] = s + } + } for i, s := range a.Revoke { if u, err := url.Parse(s); err == nil { ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() diff --git a/ca/client.go b/ca/client.go index 6bc48a42..56f17748 100644 --- a/ca/client.go +++ b/ca/client.go @@ -723,6 +723,36 @@ retry: return &sign, nil } +// RenewWithToken performs the renew request to the CA with the given +// authorization token and returns the api.SignResponse struct. This method is +// generally used to renew an expired certificate. +func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) + req, err := http.NewRequest("POST", u.String(), http.NoBody) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request") + } + req.Header.Add("Authorization", "Bearer "+token) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + } + var sign api.SignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u) + } + return &sign, nil +} + // Rekey performs the rekey request to the CA and returns the api.SignResponse // struct. func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 29a4848d..e253dab5 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -529,6 +529,74 @@ func TestClient_Renew(t *testing.T) { } } +func TestClient_RenewWithToken(t *testing.T) { + ok := &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CertChainPEM: []api.Certificate{ + {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(rootPEM)}, + }, + } + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + err error + }{ + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("Authorization") != "Bearer token" { + api.JSONStatus(w, errs.InternalServer("force"), 500) + } else { + api.JSONStatus(w, tt.response, tt.responseCode) + } + }) + + got, err := c.RenewWithToken("token") + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.RenewWithToken() = %v, want nil", got) + } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, err.Error(), tt.err.Error()) + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, diff --git a/go.mod b/go.mod index e6696529..f3ae5a09 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.15.0 + go.step.sm/crypto v0.15.3 go.step.sm/linkedca v0.10.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d diff --git a/go.sum b/go.sum index 123df6e4..f634a2ce 100644 --- a/go.sum +++ b/go.sum @@ -683,8 +683,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= -go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= +go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= +go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= go.step.sm/linkedca v0.10.0 h1:+bqymMRulHYkVde4l16FnqFVskoS6HCWJN5Z5cxAqF8= go.step.sm/linkedca v0.10.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= From f8df6a1acc39c21944cd9b1ed5cada03b7d668c5 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Mar 2022 10:05:35 -0800 Subject: [PATCH 16/44] Change variable name for consistency --- api/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/api.go b/api/api.go index 7786bd0d..912e39dd 100644 --- a/api/api.go +++ b/api/api.go @@ -33,7 +33,7 @@ type Authority interface { // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) - AuthorizeRenewToken(ctx context.Context, token string) (*x509.Certificate, error) + AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) From 236caaa735291f5315c41099a354dcb2c93f42d8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Mar 2022 10:51:33 -0800 Subject: [PATCH 17/44] Add entry in changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09c9f197..d66c45c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added +- Added support for renew after expiry using the claim `enableRenewAfterExpiry`. ### Changed ### Deprecated ### Removed From 4690fa64ed9154106140552673d6d2b7ef211f3f Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Mar 2022 14:59:42 -0800 Subject: [PATCH 18/44] Add public methods to retrieve the provisioner extensions. --- authority/provisioner/acme_test.go | 2 +- authority/provisioner/aws_test.go | 2 +- authority/provisioner/azure_test.go | 2 +- authority/provisioner/collection.go | 4 +- authority/provisioner/collection_test.go | 30 ++-- authority/provisioner/extension.go | 73 ++++++++ authority/provisioner/extension_test.go | 158 ++++++++++++++++++ authority/provisioner/gcp_test.go | 2 +- authority/provisioner/jwk_test.go | 2 +- authority/provisioner/k8sSA_test.go | 2 +- authority/provisioner/oidc_test.go | 2 +- authority/provisioner/sign_options.go | 53 ++---- authority/provisioner/sign_options_test.go | 6 +- .../testdata/certs/bad-extension.crt | 21 +++ .../testdata/certs/good-extension.crt | 22 +++ authority/provisioner/x5c_test.go | 2 +- 16 files changed, 317 insertions(+), 66 deletions(-) create mode 100644 authority/provisioner/extension.go create mode 100644 authority/provisioner/extension_test.go create mode 100644 authority/provisioner/testdata/certs/bad-extension.crt create mode 100644 authority/provisioner/testdata/certs/good-extension.crt diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index a74ef76e..bc4e97e0 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -179,7 +179,7 @@ func TestACME_AuthorizeSign(t *testing.T) { for _, o := range opts { switch v := o.(type) { case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeACME)) + assert.Equals(t, v.Type, TypeACME) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 3c6f8741..559a48f1 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -677,7 +677,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAWS)) + assert.Equals(t, v.Type, TypeAWS) assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index da342ea4..c05685b7 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -506,7 +506,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAzure)) + assert.Equals(t, v.Type, TypeAzure) assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 1bec8689..8bbace5f 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // proper id to load the provisioner. func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { for _, e := range cert.Extensions { - if e.Id.Equal(stepOIDProvisioner) { - var provisioner stepProvisionerASN1 + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 348b797c..24db4593 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) { } func TestCollection_LoadByCertificate(t *testing.T) { + mustExtension := func(typ Type, name, credentialID string) pkix.Extension { + e := Extension{ + Type: typ, Name: name, CredentialID: credentialID, + } + ext, err := e.ToExtension() + if err != nil { + t.Fatal(err) + } + return ext + } + p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateOIDC() @@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) { byName.Store(p2.GetName(), p2) byName.Store(p3.GetName(), p3) - ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) - assert.FatalError(t, err) - ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) - assert.FatalError(t, err) - ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "") - assert.FatalError(t, err) - notFoundExt, err := createProvisionerExtension(1, "foo", "bar") - assert.FatalError(t, err) - ok1Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok1Ext}, + Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)}, } ok2Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok2Ext}, + Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)}, } ok3Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok3Ext}, + Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")}, } notFoundCert := &x509.Certificate{ - Extensions: []pkix.Extension{notFoundExt}, + Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")}, } badCert := &x509.Certificate{ Extensions: []pkix.Extension{ - {Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, + {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")}, }, } diff --git a/authority/provisioner/extension.go b/authority/provisioner/extension.go new file mode 100644 index 00000000..c316329d --- /dev/null +++ b/authority/provisioner/extension.go @@ -0,0 +1,73 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" +) + +var ( + // StepOIDRoot is the root OID for smallstep. + StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + + // StepOIDProvisioner is the OID for the provisioner extension. + StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...) +) + +// Extension is the Go representation of the provisioner extension. +type Extension struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string +} + +type extensionASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` +} + +// Marshal marshals the extension using encoding/asn1. +func (e *Extension) Marshal() ([]byte, error) { + return asn1.Marshal(extensionASN1{ + Type: int(e.Type), + Name: []byte(e.Name), + CredentialID: []byte(e.CredentialID), + KeyValuePairs: e.KeyValuePairs, + }) +} + +// ToExtension returns the pkix.Extension representation of the provisioner +// extension. +func (e *Extension) ToExtension() (pkix.Extension, error) { + b, err := e.Marshal() + if err != nil { + return pkix.Extension{}, err + } + return pkix.Extension{ + Id: StepOIDProvisioner, + Value: b, + }, nil +} + +// GetProvisionerExtension goes through all the certificate extensions and +// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1). +func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) { + for _, e := range cert.Extensions { + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 + if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { + return nil, false + } + return &Extension{ + Type: Type(provisioner.Type), + Name: string(provisioner.Name), + CredentialID: string(provisioner.CredentialID), + KeyValuePairs: provisioner.KeyValuePairs, + }, true + } + } + return nil, false +} diff --git a/authority/provisioner/extension_test.go b/authority/provisioner/extension_test.go new file mode 100644 index 00000000..69be9e18 --- /dev/null +++ b/authority/provisioner/extension_test.go @@ -0,0 +1,158 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "testing" + + "go.step.sm/crypto/pemutil" +) + +func TestExtension_Marshal(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.Marshal() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want) + } + }) + } +} + +func TestExtension_ToExtension(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want pkix.Extension + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.ToExtension() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetProvisionerExtension(t *testing.T) { + mustCertificate := func(fn string) *x509.Certificate { + cert, err := pemutil.ReadCertificate(fn) + if err != nil { + t.Fatal(err) + } + return cert + } + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *Extension + want1 bool + }{ + {"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{ + Type: TypeJWK, + Name: "mariano@smallstep.com", + CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ", + }, true}, + {"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false}, + {"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := GetProvisionerExtension(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 94fbd576..b8c437c3 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -549,7 +549,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeGCP)) + assert.Equals(t, v.Type, TypeGCP) assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index bf5c3d2c..dde2f836 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -300,7 +300,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeJWK)) + assert.Equals(t, v.Type, TypeJWK) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 0a82e8ef..378d4471 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -283,7 +283,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeK8sSA)) + assert.Equals(t, v.Type, TypeK8sSA) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 62082fb2..c1a94b1d 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -327,7 +327,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeOIDC)) + assert.Equals(t, v.Type, TypeOIDC) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 34b2e99b..80dfc66e 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -6,7 +6,6 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" - "encoding/asn1" "encoding/json" "net" "net/http" @@ -14,7 +13,6 @@ import ( "reflect" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" @@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -) - -type stepProvisionerASN1 struct { - Type int - Name []byte - CredentialID []byte - KeyValuePairs []string `asn1:"optional,omitempty"` -} +// type stepProvisionerASN1 struct { +// Type int +// Name []byte +// CredentialID []byte +// KeyValuePairs []string `asn1:"optional,omitempty"` +// } type forceCNOption struct { ForceCN bool @@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { } type provisionerExtensionOption struct { - Type int - Name string - CredentialID string - KeyValuePairs []string + Extension } func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { return &provisionerExtensionOption{ - Type: int(typ), - Name: name, - CredentialID: credentialID, - KeyValuePairs: keyValuePairs, + Extension: Extension{ + Type: typ, + Name: name, + CredentialID: credentialID, + KeyValuePairs: keyValuePairs, + }, } } func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { - ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) + ext, err := o.ToExtension() if err != nil { return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } @@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) return nil } - -func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) { - b, err := asn1.Marshal(stepProvisionerASN1{ - Type: typ, - Name: []byte(name), - CredentialID: []byte(credentialID), - KeyValuePairs: keyValuePairs, - }) - if err != nil { - return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension") - } - return pkix.Extension{ - Id: stepOIDProvisioner, - Critical: false, - Value: b, - }, nil -} diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 32b8e3c6..fc4d675a 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) { valid: func(cert *x509.Certificate) { if assert.Len(t, 1, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) } }, } }, "ok/prepend": func() test { return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, + cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, valid: func(cert *x509.Certificate) { if assert.Len(t, 3, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) assert.False(t, ext.Critical) } }, diff --git a/authority/provisioner/testdata/certs/bad-extension.crt b/authority/provisioner/testdata/certs/bad-extension.crt new file mode 100644 index 00000000..ecce0f28 --- /dev/null +++ b/authority/provisioner/testdata/certs/bad-extension.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy +MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc +pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49 +BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY +wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/testdata/certs/good-extension.crt b/authority/provisioner/testdata/certs/good-extension.crt new file mode 100644 index 00000000..103353a7 --- /dev/null +++ b/authority/provisioner/testdata/certs/good-extension.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx +MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0 +4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz +bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf +SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB +bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 18a31b04..84e29b48 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -469,7 +469,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeX5C)) + assert.Equals(t, v.Type, TypeX5C) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) From a4dd586a81847f267aa75689f32b61dc712c6b60 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Mar 2022 15:13:39 -0800 Subject: [PATCH 19/44] Add method to get the CA url from the client. --- ca/client.go | 5 +++++ ca/client_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/ca/client.go b/ca/client.go index 56f17748..40618330 100644 --- a/ca/client.go +++ b/ca/client.go @@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool { return false } +// GetCaURL returns the configura CA url. +func (c *Client) GetCaURL() string { + return c.endpoint.String() +} + // GetRootCAs returns the RootCAs certificate pool from the configured // transport. func (c *Client) GetRootCAs() *x509.CertPool { diff --git a/ca/client_test.go b/ca/client_test.go index e253dab5..6e352291 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -1128,3 +1128,28 @@ func TestClient_SSHBastion(t *testing.T) { }) } } + +func TestClient_GetCaURL(t *testing.T) { + tests := []struct { + name string + caURL string + want string + }{ + {"ok", "https://ca.com", "https://ca.com"}, + {"ok no schema", "ca.com", "https://ca.com"}, + {"ok with port", "https://ca.com:9000", "https://ca.com:9000"}, + {"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(tt.caURL) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + if got := c.GetCaURL(); got != tt.want { + t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) + } + }) + } +} From 6dcde8a7438f24d28d6e6013d1e6f8e3b1e40649 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Mar 2022 15:22:53 -0800 Subject: [PATCH 20/44] Fix typo --- ca/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ca/client.go b/ca/client.go index 40618330..3a36fcd6 100644 --- a/ca/client.go +++ b/ca/client.go @@ -563,7 +563,7 @@ func (c *Client) retryOnError(r *http.Response) bool { return false } -// GetCaURL returns the configura CA url. +// GetCaURL returns the configured CA url. func (c *Client) GetCaURL() string { return c.endpoint.String() } From f7a044222e6680c361818374793561b7bd262cdc Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Mon, 14 Mar 2022 13:18:44 +0200 Subject: [PATCH 21/44] git: ignore .envrc files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d87786b0..299a2c16 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ coverage.txt output vendor .idea +.envrc From c903f00cd4ba7110a8b55910e063dc75747f0f90 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Mar 2022 15:40:01 -0700 Subject: [PATCH 22/44] Rename claim to allowRenewAfterExpiry. --- CHANGELOG.md | 2 +- authority/config/config.go | 29 ++++++++--------- authority/provisioner/claims.go | 40 ++++++++++++------------ authority/provisioner/controller.go | 4 +-- authority/provisioner/controller_test.go | 12 +++---- authority/provisioner/utils_test.go | 32 +++++++++---------- authority/provisioners.go | 11 +++++-- go.mod | 2 +- go.sum | 4 +-- 9 files changed, 72 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d66c45c0..28dfe305 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added -- Added support for renew after expiry using the claim `enableRenewAfterExpiry`. +- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. ### Changed ### Deprecated ### Removed diff --git a/authority/config/config.go b/authority/config/config.go index e4fcc863..2c437725 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -26,26 +26,27 @@ var ( DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false - // DefaultEnableRenewAfterExpiry enables renewals even when the certificate is expired. - DefaultEnableRenewAfterExpiry = false + // DefaultAllowRenewAfterExpiry allows renewals even if the certificate is + // expired. + DefaultAllowRenewAfterExpiry = false // DefaultEnableSSHCA enable SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // GlobalProvisionerClaims default claims for the Authority. Can be overridden // by provisioner specific claims. GlobalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &DefaultEnableSSHCA, - DisableRenewal: &DefaultDisableRenewal, - EnableRenewAfterExpiry: &DefaultEnableRenewAfterExpiry, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + DisableRenewal: &DefaultDisableRenewal, + AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry, } ) diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index c8bee2e5..2a3e2c61 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -24,8 +24,8 @@ type Claims struct { EnableSSHCA *bool `json:"enableSSHCA,omitempty"` // Renewal properties - DisableRenewal *bool `json:"disableRenewal,omitempty"` - EnableRenewAfterExpiry *bool `json:"enableRenewAfterExpiry,omitempty"` + DisableRenewal *bool `json:"disableRenewal,omitempty"` + AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() - enablerenewAfterExpiry := c.IsRenewAfterExpiry() + allowRenewAfterExpiry := c.AllowRenewAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, - MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, - DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, - MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, - MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, - DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, - EnableSSHCA: &enableSSHCA, - DisableRenewal: &disableRenewal, - EnableRenewAfterExpiry: &enablerenewAfterExpiry, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, + DisableRenewal: &disableRenewal, + AllowRenewAfterExpiry: &allowRenewAfterExpiry, } } @@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } -// IsRenewAfterExpiry returns if the renewal flow is authorized even if the +// AllowRenewAfterExpiry returns if the renewal flow is authorized if the // certificate is expired. If the property is not set within the provisioner // then the global value from the authority configuration will be used. -func (c *Claimer) IsRenewAfterExpiry() bool { - if c.claims == nil || c.claims.EnableRenewAfterExpiry == nil { - return *c.global.EnableRenewAfterExpiry +func (c *Claimer) AllowRenewAfterExpiry() bool { + if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil { + return *c.global.AllowRenewAfterExpiry } - return *c.claims.EnableRenewAfterExpiry + return *c.claims.AllowRenewAfterExpiry } // DefaultSSHCertDuration returns the default SSH certificate duration for the diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 97ebe8f8..a91ebaac 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -124,7 +124,7 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif if now.Before(cert.NotBefore) { return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) } - if now.After(cert.NotAfter) && !p.Claimer.IsRenewAfterExpiry() { + if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() { return errs.Unauthorized("certificate has expired") } @@ -144,7 +144,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { return errs.Unauthorized("certificate is not yet valid") } - if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.IsRenewAfterExpiry() { + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() { return errs.Unauthorized("certificate has expired") } diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index bbf7cb81..9fb90e9d 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -160,13 +160,13 @@ func TestController_AuthorizeRenew(t *testing.T) { NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, - {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { return nil }}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, - {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, false}, @@ -231,13 +231,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) { ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, - {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { return nil }}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, - {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, false}, @@ -296,7 +296,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) { }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, - Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), @@ -354,7 +354,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) { }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, - Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index ff8421f0..669693d6 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,22 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultEnableRenewAfterExpiry = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, - DisableRenewal: &defaultDisableRenewal, - EnableRenewAfterExpiry: &defaultEnableRenewAfterExpiry, + defaultDisableRenewal = false + defaultAllowRenewAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, diff --git a/authority/provisioners.go b/authority/provisioners.go index 780d12c0..a6ac5aa8 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -437,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { } pc := &provisioner.Claims{ - DisableRenewal: &c.DisableRenewal, + DisableRenewal: &c.DisableRenewal, + AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry, } var err error @@ -475,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { } disableRenewal := config.DefaultDisableRenewal + allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry + if c.DisableRenewal != nil { disableRenewal = *c.DisableRenewal } + if c.AllowRenewAfterExpiry != nil { + allowRenewAfterExpiry = *c.AllowRenewAfterExpiry + } lc := &linkedca.Claims{ - DisableRenewal: disableRenewal, + DisableRenewal: disableRenewal, + AllowRenewAfterExpiry: allowRenewAfterExpiry, } if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { diff --git a/go.mod b/go.mod index f3ae5a09..6033d05e 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 go.step.sm/crypto v0.15.3 - go.step.sm/linkedca v0.10.0 + go.step.sm/linkedca v0.11.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect diff --git a/go.sum b/go.sum index f634a2ce..c7a18aad 100644 --- a/go.sum +++ b/go.sum @@ -685,8 +685,8 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/ go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.10.0 h1:+bqymMRulHYkVde4l16FnqFVskoS6HCWJN5Z5cxAqF8= -go.step.sm/linkedca v0.10.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= +go.step.sm/linkedca v0.11.0 h1:jkG5XDQz9VSz2PH+cGjDvJTwiIziN0SWExTnicWpb8o= +go.step.sm/linkedca v0.11.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= From 6d532045dcdd93d7c8af98d66eb5427c3df82e0b Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Mar 2022 17:31:21 -0700 Subject: [PATCH 23/44] Fix validity check for sshpop provisioner. --- authority/provisioner/sshpop.go | 26 +++++++++++++++----------- authority/provisioner/sshpop_test.go | 2 +- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index a7df38de..9de0fca2 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -104,7 +104,7 @@ func (p *SSHPOP) Init(config Config) (err error) { // e.g. a Sign request will auth/validate different fields than a Revoke request. // // Checking for certificate revocation has been moved to the authority package. -func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { +func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) { sshCert, jwt, err := ExtractSSHPOPCert(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, @@ -112,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa } // Check validity period of the certificate. - n := time.Now() - if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") - } - if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + // + // Controller.AuthorizeSSHRenew will validate this on the renewal flow. + if checkValidity { + unixNow := time.Now().Unix() + if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") + } + if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + } } + sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) if !ok { return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") @@ -181,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // AuthorizeSSHRevoke validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { - claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } @@ -194,7 +199,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { // AuthorizeSSHRenew validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } @@ -207,7 +212,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert // AuthorizeSSHRekey validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } @@ -222,7 +227,6 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil - } // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 715bf6de..b548fe71 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -214,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { + if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) From 81b0c6c37c6129930ffb659cb8758952a7834e91 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 15 Mar 2022 15:51:45 +0100 Subject: [PATCH 24/44] Add API implementation for authority and provisioner policy --- api/utils.go | 7 + authority/admin/api/admin.go | 4 + authority/admin/api/admin_test.go | 21 ++ authority/admin/api/handler.go | 52 +++- authority/admin/api/middleware.go | 23 ++ authority/admin/api/policy.go | 313 ++++++++++++++++++++++++ authority/admin/db.go | 37 +++ authority/admin/db/nosql/nosql.go | 7 +- authority/admin/db/nosql/policy.go | 144 +++++++++++ authority/admin/db/nosql/provisioner.go | 4 + authority/authority.go | 61 +++-- authority/linkedca.go | 20 ++ authority/policy.go | 132 ++++++++++ authority/policy/options.go | 34 ++- authority/provisioners.go | 53 ++++ authority/tls.go | 5 +- ca/ca.go | 3 +- go.mod | 2 +- go.sum | 4 - 19 files changed, 883 insertions(+), 43 deletions(-) create mode 100644 authority/admin/api/policy.go create mode 100644 authority/admin/db/nosql/policy.go create mode 100644 authority/policy.go diff --git a/api/utils.go b/api/utils.go index a7f4bf58..b6ff7960 100644 --- a/api/utils.go +++ b/api/utils.go @@ -66,6 +66,13 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { LogEnabledResponse(w, v) } +// JSONNotFound writes a HTTP Not Found response with empty body. +func JSONNotFound(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + LogEnabledResponse(w, nil) +} + // ProtoJSON writes the passed value into the http.ResponseWriter. func ProtoJSON(w http.ResponseWriter, m proto.Message) { ProtoJSONStatus(w, m, http.StatusOK) diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 7aa66d0f..dd40784b 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -25,6 +25,10 @@ type adminAuthority interface { LoadProvisionerByID(id string) (provisioner.Interface, error) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + RemoveAuthorityPolicy(ctx context.Context) error } // CreateAdminRequest represents the body for a CreateAdmin request. diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index 8d223b52..f1698139 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -37,6 +37,11 @@ type mockAdminAuthority struct { MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockRemoveProvisioner func(ctx context.Context, id string) error + + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockStoreAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockRemoveAuthorityPolicy func(ctx context.Context) error } func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { @@ -130,6 +135,22 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e return m.MockErr } +func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") +} + +func (m *mockAdminAuthority) StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { + return errors.New("not implemented yet") +} + func TestCreateAdminRequest_Validate(t *testing.T) { type fields struct { Subject string diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 99e74c88..e59b95e0 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -8,32 +8,44 @@ import ( // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB - acmeResponder acmeAdminResponderInterface + adminDB admin.DB + auth adminAuthority + acmeDB acme.DB + acmeResponder acmeAdminResponderInterface + policyResponder policyAdminResponderInterface } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, - acmeResponder: acmeResponder, + auth: auth, + adminDB: adminDB, + acmeDB: acmeDB, + acmeResponder: acmeResponder, + policyResponder: policyResponder, } } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { + authnz := func(next nextHTTP) nextHTTP { - return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + //return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + return h.requireAPIEnabled(next) // TODO(hs): remove this; temporarily no auth checks for simple testing... } requireEABEnabled := func(next nextHTTP) nextHTTP { return h.requireEABEnabled(next) } + enabledInStandalone := func(next nextHTTP) nextHTTP { + return h.checkAction(next, true) + } + + disabledInStandalone := func(next nextHTTP) nextHTTP { + return h.checkAction(next, false) + } + // Provisioners r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) @@ -53,4 +65,24 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) + + // Policy - Authority + r.MethodFunc("GET", "/policy", authnz(enabledInStandalone(h.policyResponder.GetAuthorityPolicy))) + r.MethodFunc("POST", "/policy", authnz(enabledInStandalone(h.policyResponder.CreateAuthorityPolicy))) + r.MethodFunc("PUT", "/policy", authnz(enabledInStandalone(h.policyResponder.UpdateAuthorityPolicy))) + r.MethodFunc("DELETE", "/policy", authnz(enabledInStandalone(h.policyResponder.DeleteAuthorityPolicy))) + + // Policy - Provisioner + //r.MethodFunc("GET", "/provisioners/{name}/policy", noauth(h.policyResponder.GetProvisionerPolicy)) + r.MethodFunc("GET", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.GetProvisionerPolicy))) + r.MethodFunc("POST", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.CreateProvisionerPolicy))) + r.MethodFunc("PUT", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateProvisionerPolicy))) + r.MethodFunc("DELETE", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteProvisionerPolicy))) + + // Policy - ACME Account + // TODO: ensure we don't clash with eab; might want to change eab paths slightly (as long as we don't have it released completely; needs changes in adminClient too) + r.MethodFunc("GET", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.GetACMEAccountPolicy))) + r.MethodFunc("POST", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.CreateACMEAccountPolicy))) + r.MethodFunc("PUT", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateACMEAccountPolicy))) + r.MethodFunc("DELETE", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteACMEAccountPolicy))) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 19025a9d..62aefdc3 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -6,6 +6,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/admin/db/nosql" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -44,6 +45,28 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { } } +// checkAction checks if an action is supported in standalone or not +func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + + // actions allowed in standalone mode are always allowed + if supportedInStandalone { + next(w, r) + return + } + + // when in standalone mode, actions are not supported + if _, ok := h.adminDB.(*nosql.DB); ok { + api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + "operation not supported in standalone mode")) + return + } + + // continue to next http handler + next(w, r) + } +} + // ContextKey is the key type for storing and searching for ACME request // essentials in the context of a request. type ContextKey string diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go new file mode 100644 index 00000000..c318e5e5 --- /dev/null +++ b/authority/admin/api/policy.go @@ -0,0 +1,313 @@ +package api + +import ( + "net/http" + + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" +) + +type policyAdminResponderInterface interface { + GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) + CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) + GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) + CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) + GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) +} + +// PolicyAdminResponder is responsible for writing ACME admin responses +type PolicyAdminResponder struct { + auth adminAuthority + adminDB admin.DB +} + +// NewACMEAdminResponder returns a new ACMEAdminResponder +func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB) *PolicyAdminResponder { + return &PolicyAdminResponder{ + auth: auth, + adminDB: adminDB, + } +} + +// GetAuthorityPolicy handles the GET /admin/authority/policy request +func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + policy, err := par.auth.GetAuthorityPolicy(r.Context()) + if ae, ok := err.(*admin.Error); ok { + if !ae.IsType(admin.ErrorNotFoundType) { + api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + } + + if policy == nil { + api.JSONNotFound(w) + return + } + + api.ProtoJSONStatus(w, policy, http.StatusOK) +} + +// CreateAuthorityPolicy handles the POST /admin/authority/policy request +func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + policy, err := par.auth.GetAuthorityPolicy(ctx) + + shouldWriteError := false + if ae, ok := err.(*admin.Error); ok { + shouldWriteError = !ae.IsType(admin.ErrorNotFoundType) + } + + if shouldWriteError { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if policy != nil { + adminErr := admin.NewError(admin.ErrorBadRequestType, "authority already has a policy") + adminErr.Status = http.StatusConflict + api.WriteError(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { + api.WriteError(w, err) + return + } + + if err := par.auth.StoreAuthorityPolicy(ctx, newPolicy); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error storing authority policy")) + return + } + + storedPolicy, err := par.auth.GetAuthorityPolicy(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) + return + } + + api.JSONStatus(w, storedPolicy, http.StatusCreated) +} + +// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request +func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + var policy = new(linkedca.Policy) + if err := api.ReadProtoJSON(r.Body, policy); err != nil { + api.WriteError(w, err) + return + } + + ctx := r.Context() + if err := par.auth.UpdateAuthorityPolicy(ctx, policy); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error updating authority policy")) + return + } + + newPolicy, err := par.auth.GetAuthorityPolicy(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) + return + } + + api.ProtoJSONStatus(w, newPolicy, http.StatusOK) +} + +// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request +func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + policy, err := par.auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok { + if !ae.IsType(admin.ErrorNotFoundType) { + api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + } + + if policy == nil { + api.JSONNotFound(w) + return + } + + err = par.auth.RemoveAuthorityPolicy(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error deleting authority policy")) + return + } + + api.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + // TODO: move getting provisioner to middleware? + ctx := r.Context() + name := chi.URLParam(r, "name") + var ( + p provisioner.Interface + err error + ) + if p, err = par.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := par.adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + + policy := prov.GetPolicy() + if policy == nil { + api.JSONNotFound(w) + return + } + + api.ProtoJSONStatus(w, policy, http.StatusOK) +} + +// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + name := chi.URLParam(r, "name") + var ( + p provisioner.Interface + err error + ) + if p, err = par.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := par.adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + + policy := prov.GetPolicy() + if policy != nil { + adminErr := admin.NewError(admin.ErrorBadRequestType, "provisioner %s already has a policy", name) + adminErr.Status = http.StatusConflict + api.WriteError(w, adminErr) + } + + var newPolicy = new(linkedca.Policy) + if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { + api.WriteError(w, err) + return + } + + prov.Policy = newPolicy + + err = par.auth.UpdateProvisioner(ctx, prov) + if err != nil { + api.WriteError(w, err) + return + } + + api.ProtoJSONStatus(w, newPolicy, http.StatusCreated) +} + +// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + name := chi.URLParam(r, "name") + var ( + p provisioner.Interface + err error + ) + if p, err = par.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := par.adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + + var policy = new(linkedca.Policy) + if err := api.ReadProtoJSON(r.Body, policy); err != nil { + api.WriteError(w, err) + return + } + + prov.Policy = policy + err = par.auth.UpdateProvisioner(ctx, prov) + if err != nil { + api.WriteError(w, err) + return + } + + api.ProtoJSONStatus(w, policy, http.StatusOK) +} + +// DeleteProvisionerPolicy ... +func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + name := chi.URLParam(r, "name") + var ( + p provisioner.Interface + err error + ) + if p, err = par.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := par.adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + + if prov.Policy == nil { + api.JSONNotFound(w) + return + } + + // remove the policy + prov.Policy = nil + + err = par.auth.UpdateProvisioner(ctx, prov) + if err != nil { + api.WriteError(w, err) + return + } + + api.JSON(w, &DeleteResponse{Status: "ok"}) +} + +// GetACMEAccountPolicy ... +func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + api.JSON(w, "ok") +} + +func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + api.JSON(w, "ok") +} + +func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + api.JSON(w, "ok") +} + +func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + api.JSON(w, "ok") +} diff --git a/authority/admin/db.go b/authority/admin/db.go index bf34a3c2..75ac1368 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -69,6 +69,11 @@ type DB interface { GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error DeleteAdmin(ctx context.Context, id string) error + + CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + DeleteAuthorityPolicy(ctx context.Context) error } // MockDB is an implementation of the DB interface that should only be used as @@ -86,6 +91,11 @@ type MockDB struct { MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockDeleteAdmin func(ctx context.Context, id string) error + MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockDeleteAuthorityPolicy func(ctx context.Context) error + MockError error MockRet1 interface{} } @@ -179,3 +189,30 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error { } return m.MockError } + +func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockCreateAuthorityPolicy != nil { + return m.MockCreateAuthorityPolicy(ctx, policy) + } + return m.MockError +} +func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + if m.MockGetAuthorityPolicy != nil { + return m.MockGetAuthorityPolicy(ctx) + } + return m.MockRet1.(*linkedca.Policy), m.MockError +} + +func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockUpdateAuthorityPolicy != nil { + return m.MockUpdateAuthorityPolicy(ctx, policy) + } + return m.MockError +} + +func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error { + if m.MockDeleteAuthorityPolicy != nil { + return m.MockDeleteAuthorityPolicy(ctx) + } + return m.MockError +} diff --git a/authority/admin/db/nosql/nosql.go b/authority/admin/db/nosql/nosql.go index 22b049f5..32e05d92 100644 --- a/authority/admin/db/nosql/nosql.go +++ b/authority/admin/db/nosql/nosql.go @@ -11,8 +11,9 @@ import ( ) var ( - adminsTable = []byte("admins") - provisionersTable = []byte("provisioners") + adminsTable = []byte("admins") + provisionersTable = []byte("provisioners") + authorityPoliciesTable = []byte("authority_policies") ) // DB is a struct that implements the AdminDB interface. @@ -23,7 +24,7 @@ type DB struct { // New configures and returns a new Authority DB backend implemented using a nosql DB. func New(db nosqlDB.DB, authorityID string) (*DB, error) { - tables := [][]byte{adminsTable, provisionersTable} + tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", diff --git a/authority/admin/db/nosql/policy.go b/authority/admin/db/nosql/policy.go new file mode 100644 index 00000000..94ff2a0e --- /dev/null +++ b/authority/admin/db/nosql/policy.go @@ -0,0 +1,144 @@ +package nosql + +import ( + "context" + "encoding/json" + + "github.com/pkg/errors" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/nosql" + "go.step.sm/linkedca" +) + +type dbAuthorityPolicy struct { + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + Policy *linkedca.Policy `json:"policy"` +} + +func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy { + return dbap.Policy +} + +func (dbap *dbAuthorityPolicy) clone() *dbAuthorityPolicy { + u := *dbap + return &u +} + +func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) { + data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) + if nosql.IsErrNotFound(err) { + return nil, admin.NewError(admin.ErrorNotFoundType, "policy %s not found", authorityID) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading admin %s", authorityID) + } + return data, nil +} + +func (db *DB) unmarshalDBAuthorityPolicy(data []byte, authorityID string) (*dbAuthorityPolicy, error) { + var dba = new(dbAuthorityPolicy) + if err := json.Unmarshal(data, dba); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", authorityID) + } + // if !dba.DeletedAt.IsZero() { + // return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", authorityID) + // } + if dba.AuthorityID != db.authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "admin %s is not owned by authority %s", dba.ID, db.authorityID) + } + return dba, nil +} + +func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) { + data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID) + if err != nil { + return nil, err + } + dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID) + if err != nil { + return nil, err + } + return dbap, nil +} + +func (db *DB) unmarshalAuthorityPolicy(data []byte, authorityID string) (*linkedca.Policy, error) { + dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID) + if err != nil { + return nil, err + } + return dbap.convert(), nil +} + +func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: policy, + } + + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) +} + +func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + // policy := &linkedca.Policy{ + // X509: &linkedca.X509Policy{ + // Allow: &linkedca.X509Names{ + // Dns: []string{".localhost"}, + // }, + // Deny: &linkedca.X509Names{ + // Dns: []string{"denied.localhost"}, + // }, + // }, + // Ssh: &linkedca.SSHPolicy{ + // User: &linkedca.SSHUserPolicy{ + // Allow: &linkedca.SSHUserNames{}, + // Deny: &linkedca.SSHUserNames{}, + // }, + // Host: &linkedca.SSHHostPolicy{ + // Allow: &linkedca.SSHHostNames{}, + // Deny: &linkedca.SSHHostNames{}, + // }, + // }, + // } + + dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return nil, err + } + + return dbap.convert(), nil +} + +func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: policy, + } + + return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) +} + +func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error { + dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + old := dbap.clone() + + dbap.Policy = nil + return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) +} diff --git a/authority/admin/db/nosql/provisioner.go b/authority/admin/db/nosql/provisioner.go index 71d9c8d6..540e3ae2 100644 --- a/authority/admin/db/nosql/provisioner.go +++ b/authority/admin/db/nosql/provisioner.go @@ -19,6 +19,7 @@ type dbProvisioner struct { Type linkedca.Provisioner_Type `json:"type"` Name string `json:"name"` Claims *linkedca.Claims `json:"claims"` + Policy *linkedca.Policy `json:"policy"` Details []byte `json:"details"` X509Template *linkedca.Template `json:"x509Template"` SSHTemplate *linkedca.Template `json:"sshTemplate"` @@ -43,6 +44,7 @@ func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) { Type: dbp.Type, Name: dbp.Name, Claims: dbp.Claims, + Policy: dbp.Policy, Details: details, X509Template: dbp.X509Template, SshTemplate: dbp.SSHTemplate, @@ -160,6 +162,7 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) Type: prov.Type, Name: prov.Name, Claims: prov.Claims, + Policy: prov.Policy, Details: details, X509Template: prov.X509Template, SSHTemplate: prov.SshTemplate, @@ -187,6 +190,7 @@ func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) } nu.Name = prov.Name nu.Claims = prov.Claims + nu.Policy = prov.Policy nu.Details, err = json.Marshal(prov.Details.GetData()) if err != nil { return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name) diff --git a/authority/authority.go b/authority/authority.go index 4eacfad7..aaf0e478 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -205,6 +205,47 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error { a.provisioners = provClxn a.config.AuthorityConfig.Admins = adminList a.admins = adminClxn + + return nil +} + +// reloadPolicyEngines reloads x509 and SSH policy engines using +// configuration stored in the DB or from the configuration file. +func (a *Authority) reloadPolicyEngines(ctx context.Context) error { + var ( + err error + policyOptions *policy.Options + ) + if a.config.AuthorityConfig.EnableAdmin { + linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error getting policy to initialize authority") + } + policyOptions = policyToCertificates(linkedPolicy) + } else { + policyOptions = a.config.AuthorityConfig.Policy + } + + // return early if no policy options set + if policyOptions == nil { + return nil + } + + // Initialize the x509 allow/deny policy engine + if a.x509Policy, err = policy.NewX509PolicyEngine(policyOptions.GetX509Options()); err != nil { + return err + } + + // // Initialize the SSH allow/deny policy engine for host certificates + if a.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(policyOptions.GetSSHOptions()); err != nil { + return err + } + + // // Initialize the SSH allow/deny policy engine for user certificates + if a.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(policyOptions.GetSSHOptions()); err != nil { + return err + } + return nil } @@ -533,6 +574,11 @@ func (a *Authority) init() error { return err } + // Load Policy Engines + if err := a.reloadPolicyEngines(context.Background()); err != nil { + return err + } + // Configure templates, currently only ssh templates are supported. if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { a.templates = a.config.Templates @@ -545,21 +591,6 @@ func (a *Authority) init() error { a.templates.Data["Step"] = tmplVars } - // Initialize the x509 allow/deny policy engine - if a.x509Policy, err = policy.NewX509PolicyEngine(a.config.AuthorityConfig.Policy.GetX509Options()); err != nil { - return err - } - - // // Initialize the SSH allow/deny policy engine for host certificates - if a.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(a.config.AuthorityConfig.Policy.GetSSHOptions()); err != nil { - return err - } - - // // Initialize the SSH allow/deny policy engine for user certificates - if a.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(a.config.AuthorityConfig.Policy.GetSSHOptions()); err != nil { - return err - } - // JWT numeric dates are seconds. a.startTime = time.Now().Truncate(time.Second) // Set flag indicating that initialization has been completed, and should diff --git a/authority/linkedca.go b/authority/linkedca.go index b568dcbb..11c8668c 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -34,6 +35,9 @@ type linkedCaClient struct { authorityID string } +// interface guard +var _ admin.DB = (*linkedCaClient)(nil) + type linkedCAClaims struct { jose.Claims SANs []string `json:"sans"` @@ -310,6 +314,22 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } +func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") +} + +func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error { + return errors.New("not implemented yet") +} + func serializeCertificate(crt *x509.Certificate) string { if crt == nil { return "" diff --git a/authority/policy.go b/authority/policy.go new file mode 100644 index 00000000..8ef264d0 --- /dev/null +++ b/authority/policy.go @@ -0,0 +1,132 @@ +package authority + +import ( + "context" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/policy" + "go.step.sm/linkedca" +) + +func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + policy, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + return nil, err + } + + return policy, nil +} + +func (a *Authority) StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.adminDB.CreateAuthorityPolicy(ctx, policy); err != nil { + return err + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources when creating authority policy") + } + + return nil +} + +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.adminDB.UpdateAuthorityPolicy(ctx, policy); err != nil { + return err + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources when updating authority policy") + } + + return nil +} + +func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil { + return err + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources when deleting authority policy") + } + + return nil +} + +func policyToCertificates(p *linkedca.Policy) *policy.Options { + // return early + if p == nil { + return nil + } + // prepare full policy struct + opts := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{}, + DeniedNames: &policy.X509NameOptions{}, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{}, + DeniedNames: &policy.SSHNameOptions{}, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{}, + DeniedNames: &policy.SSHNameOptions{}, + }, + }, + } + // fill x509 policy configuration + if p.X509 != nil { + if p.X509.Allow != nil { + opts.X509.AllowedNames.DNSDomains = p.X509.Allow.Dns + opts.X509.AllowedNames.IPRanges = p.X509.Allow.Ips + opts.X509.AllowedNames.EmailAddresses = p.X509.Allow.Emails + opts.X509.AllowedNames.URIDomains = p.X509.Allow.Uris + } + if p.X509.Deny != nil { + opts.X509.DeniedNames.DNSDomains = p.X509.Deny.Dns + opts.X509.DeniedNames.IPRanges = p.X509.Deny.Ips + opts.X509.DeniedNames.EmailAddresses = p.X509.Deny.Emails + opts.X509.DeniedNames.URIDomains = p.X509.Deny.Uris + } + } + // fill ssh policy configuration + if p.Ssh != nil { + if p.Ssh.Host != nil { + if p.Ssh.Host.Allow != nil { + opts.SSH.Host.AllowedNames.DNSDomains = p.Ssh.Host.Allow.Dns + opts.SSH.Host.AllowedNames.IPRanges = p.Ssh.Host.Allow.Ips + opts.SSH.Host.AllowedNames.EmailAddresses = p.Ssh.Host.Allow.Principals + } + if p.Ssh.Host.Deny != nil { + opts.SSH.Host.DeniedNames.DNSDomains = p.Ssh.Host.Deny.Dns + opts.SSH.Host.DeniedNames.IPRanges = p.Ssh.Host.Deny.Ips + opts.SSH.Host.DeniedNames.Principals = p.Ssh.Host.Deny.Principals + } + } + if p.Ssh.User != nil { + if p.Ssh.User.Allow != nil { + opts.SSH.User.AllowedNames.EmailAddresses = p.Ssh.User.Allow.Emails + opts.SSH.User.AllowedNames.Principals = p.Ssh.User.Allow.Principals + } + if p.Ssh.User.Deny != nil { + opts.SSH.User.DeniedNames.EmailAddresses = p.Ssh.User.Deny.Emails + opts.SSH.User.DeniedNames.Principals = p.Ssh.User.Deny.Principals + } + } + } + + return opts +} diff --git a/authority/policy/options.go b/authority/policy/options.go index f57f3bcf..5c6e6134 100644 --- a/authority/policy/options.go +++ b/authority/policy/options.go @@ -1,10 +1,14 @@ package policy +// Options is a container for authority level x509 and SSH +// policy configuration. type Options struct { X509 *X509PolicyOptions `json:"x509,omitempty"` SSH *SSHPolicyOptions `json:"ssh,omitempty"` } +// GetX509Options returns the x509 authority level policy +// configuration func (o *Options) GetX509Options() *X509PolicyOptions { if o == nil { return nil @@ -12,6 +16,8 @@ func (o *Options) GetX509Options() *X509PolicyOptions { return o.X509 } +// GetSSHOptions returns the SSH authority level policy +// configuration func (o *Options) GetSSHOptions() *SSHPolicyOptions { if o == nil { return nil @@ -19,16 +25,19 @@ func (o *Options) GetSSHOptions() *SSHPolicyOptions { return o.SSH } +// X509PolicyOptionsInterface is an interface for providers +// of x509 allowed and denied names. type X509PolicyOptionsInterface interface { GetAllowedNameOptions() *X509NameOptions GetDeniedNameOptions() *X509NameOptions } +// X509PolicyOptions is a container for x509 allowed and denied +// names. type X509PolicyOptions struct { - // AllowedNames ... + // AllowedNames contains the x509 allowed names AllowedNames *X509NameOptions `json:"allow,omitempty"` - - // DeniedNames ... + // DeniedNames contains the x509 denied names DeniedNames *X509NameOptions `json:"deny,omitempty"` } @@ -49,6 +58,8 @@ func (o *X509NameOptions) HasNames() bool { len(o.URIDomains) > 0 } +// SSHPolicyOptionsInterface is an interface for providers of +// SSH user and host name policy configuration. type SSHPolicyOptionsInterface interface { GetAllowedUserNameOptions() *SSHNameOptions GetDeniedUserNameOptions() *SSHNameOptions @@ -56,16 +67,16 @@ type SSHPolicyOptionsInterface interface { GetDeniedHostNameOptions() *SSHNameOptions } +// SSHPolicyOptions is a container for SSH user and host policy +// configuration type SSHPolicyOptions struct { // User contains SSH user certificate options. User *SSHUserCertificateOptions `json:"user,omitempty"` - // Host contains SSH host certificate options. Host *SSHHostCertificateOptions `json:"host,omitempty"` } -// GetAllowedNameOptions returns AllowedNames, which models the -// SANs that ... +// GetAllowedNameOptions returns x509 allowed name policy configuration func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions { if o == nil { return nil @@ -73,8 +84,7 @@ func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions { return o.AllowedNames } -// GetDeniedNameOptions returns the DeniedNames, which models the -// SANs that ... +// GetDeniedNameOptions returns the x509 denied name policy configuration func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions { if o == nil { return nil @@ -82,6 +92,8 @@ func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions { return o.DeniedNames } +// GetAllowedUserNameOptions returns the SSH allowed user name policy +// configuration. func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { if o == nil { return nil @@ -92,6 +104,8 @@ func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { return o.User.AllowedNames } +// GetDeniedUserNameOptions returns the SSH denied user name policy +// configuration. func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { if o == nil { return nil @@ -102,6 +116,8 @@ func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { return o.User.DeniedNames } +// GetAllowedHostNameOptions returns the SSH allowed host name policy +// configuration. func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { if o == nil { return nil @@ -112,6 +128,8 @@ func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { return o.Host.AllowedNames } +// GetDeniedHostNameOptions returns the SSH denied host name policy +// configuration. func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions { if o == nil { return nil diff --git a/authority/provisioners.go b/authority/provisioners.go index 8dc27c6a..7a579267 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "go.step.sm/cli-utils/step" @@ -395,6 +396,58 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options { ops.SSH.Template = string(p.SshTemplate.Template) ops.SSH.TemplateData = p.SshTemplate.Data } + if p.Policy != nil { + if p.Policy.X509 != nil { + if p.Policy.X509.Allow != nil { + ops.X509.AllowedNames = &policy.X509NameOptions{ + DNSDomains: p.Policy.X509.Allow.Dns, + IPRanges: p.Policy.X509.Allow.Ips, + EmailAddresses: p.Policy.X509.Allow.Emails, + URIDomains: p.Policy.X509.Allow.Uris, + } + } + if p.Policy.X509.Deny != nil { + ops.X509.DeniedNames = &policy.X509NameOptions{ + DNSDomains: p.Policy.X509.Deny.Dns, + IPRanges: p.Policy.X509.Deny.Ips, + EmailAddresses: p.Policy.X509.Deny.Emails, + URIDomains: p.Policy.X509.Deny.Uris, + } + } + } + if p.Policy.Ssh != nil { + if p.Policy.Ssh.Host != nil { + ops.SSH.Host = &policy.SSHHostCertificateOptions{} + if p.Policy.Ssh.Host.Allow != nil { + ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{ + DNSDomains: p.Policy.Ssh.Host.Allow.Dns, + IPRanges: p.Policy.Ssh.Host.Allow.Ips, + } + } + if p.Policy.Ssh.Host.Deny != nil { + ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{ + DNSDomains: p.Policy.Ssh.Host.Deny.Dns, + IPRanges: p.Policy.Ssh.Host.Deny.Ips, + } + } + } + if p.Policy.Ssh.User != nil { + ops.SSH.User = &policy.SSHUserCertificateOptions{} + if p.Policy.Ssh.User.Allow != nil { + ops.SSH.User.AllowedNames = &policy.SSHNameOptions{ + EmailAddresses: p.Policy.Ssh.User.Allow.Emails, + Principals: p.Policy.Ssh.User.Allow.Principals, + } + } + if p.Policy.Ssh.User.Deny != nil { + ops.SSH.User.DeniedNames = &policy.SSHNameOptions{ + EmailAddresses: p.Policy.Ssh.User.Deny.Emails, + Principals: p.Policy.Ssh.User.Deny.Principals, + } + } + } + } + } return ops } diff --git a/authority/tls.go b/authority/tls.go index d749e2ad..96c80e9a 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -192,7 +192,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } // If a policy is configured, perform allow/deny policy check on authority level - if a.x509Policy != nil { + // TODO: policy currently also applies to admin token certs; how to circumvent? + // Allow any name of an admin in the DB? Or in the admin collection? + todoRemoveThis := false + if todoRemoveThis && a.x509Policy != nil { allowed, err := a.x509Policy.AreCertificateNamesAllowed(leaf) if err != nil { return nil, errs.InternalServerErr(err, diff --git a/ca/ca.go b/ca/ca.go index c95ba22f..f4585aba 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -208,7 +208,8 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder) + policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB) + adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder) mux.Route("/admin", func(r chi.Router) { adminHandler.Route(r) }) diff --git a/go.mod b/go.mod index 46fe260c..76cdff9a 100644 --- a/go.mod +++ b/go.mod @@ -49,4 +49,4 @@ require ( // replace github.com/smallstep/nosql => ../nosql // replace go.step.sm/crypto => ../crypto // replace go.step.sm/cli-utils => ../cli-utils -// replace go.step.sm/linkedca => ../linkedca +replace go.step.sm/linkedca => ../linkedca diff --git a/go.sum b/go.sum index 1cd8e2e7..ba7cb531 100644 --- a/go.sum +++ b/go.sum @@ -685,10 +685,6 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/ go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.9.2 h1:CpAkd174sLXFfrOZrbPEiTzik91QRj3+L0omsiwsiok= -go.step.sm/linkedca v0.9.2/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= -go.step.sm/linkedca v0.10.0 h1:+bqymMRulHYkVde4l16FnqFVskoS6HCWJN5Z5cxAqF8= -go.step.sm/linkedca v0.10.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= From ead742ca0ff8049d52b31dd7b395b6f09308b673 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 15 Mar 2022 12:13:01 -0700 Subject: [PATCH 25/44] Fix unit test --- ca/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ca/client_test.go b/ca/client_test.go index 6e352291..a00ca1cf 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -1142,7 +1142,7 @@ func TestClient_GetCaURL(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c, err := NewClient(tt.caURL) + c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) if err != nil { t.Errorf("NewClient() error = %v", err) return From 915911efb648c1174ce0fde7c16d5da538237cae Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 15 Mar 2022 12:26:00 -0700 Subject: [PATCH 26/44] Disable http loggers in test. They hide the test that fail on tests in the CI. --- ca/testdata/ca.json | 2 +- ca/testdata/federated-ca.json | 2 +- ca/testdata/rotate-ca-0.json | 2 +- ca/testdata/rotate-ca-1.json | 2 +- ca/testdata/rotate-ca-2.json | 2 +- ca/testdata/rotate-ca-3.json | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d40325e8..2a336f24 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -6,7 +6,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.3, diff --git a/ca/testdata/federated-ca.json b/ca/testdata/federated-ca.json index 342adfcf..0b1c6c8d 100644 --- a/ca/testdata/federated-ca.json +++ b/ca/testdata/federated-ca.json @@ -6,7 +6,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-0.json b/ca/testdata/rotate-ca-0.json index 20dd603a..aa9353ed 100644 --- a/ca/testdata/rotate-ca-0.json +++ b/ca/testdata/rotate-ca-0.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-1.json b/ca/testdata/rotate-ca-1.json index b038f694..c78ba035 100644 --- a/ca/testdata/rotate-ca-1.json +++ b/ca/testdata/rotate-ca-1.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-2.json b/ca/testdata/rotate-ca-2.json index 7ec965d0..2db1c992 100644 --- a/ca/testdata/rotate-ca-2.json +++ b/ca/testdata/rotate-ca-2.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-3.json b/ca/testdata/rotate-ca-3.json index 968da6ba..50f4a118 100644 --- a/ca/testdata/rotate-ca-3.json +++ b/ca/testdata/rotate-ca-3.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, From 15477f6d7be0525574776264205ce8f6ab7a52d7 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 15 Mar 2022 23:28:56 +0100 Subject: [PATCH 27/44] Make custom SCEP CA paths automagic --- authority/provisioner/scep.go | 16 ++++------------ scep/api/api.go | 17 ++--------------- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 05802ffb..5d67762c 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -26,18 +26,10 @@ type SCEP struct { // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - // CustomPath is used to specify a custom path on which the SCEP provisioner will be made - // available. By default a SCEP provisioner is available at - // https://
:/scep/ and requests performed looking similar - // to https://
:/scep/?operations=GetCACert. When CustomPath - // is set, the SCEP URL will be https://
:/scep//, - // resulting in SCEP clients that expect a specific path, such as "/pkiclient.exe", to be - // able to interact with the SCEP provisioner. - CustomPath string `json:"customPath,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + claimer *Claimer secretChallengePassword string encryptionAlgorithm int diff --git a/scep/api/api.go b/scep/api/api.go index 9b48187a..77c683ee 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -66,9 +66,9 @@ func New(scepAuth scep.Interface) api.RouterHandler { // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { getLink := h.Auth.GetLinkExplicit - r.MethodFunc(http.MethodGet, getLink("{provisionerName}/{customPath}*", false, nil), h.lookupProvisioner(h.Get)) + r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}/{customPath}*", false, nil), h.lookupProvisioner(h.Post)) + r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) } @@ -193,13 +193,6 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return } - customPathParam := chi.URLParam(r, "customPath") - customPath, err := url.PathUnescape(customPathParam) - if err != nil { - api.WriteError(w, err) - return - } - p, err := h.Auth.LoadProvisionerByName(provisionerName) if err != nil { api.WriteError(w, err) @@ -212,12 +205,6 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return } - configuredCustomPath := strings.Trim(prov.CustomPath, "/") - if customPath != configuredCustomPath { - api.WriteError(w, errors.Errorf("custom path requested '%s' is not the expected path '%s'", customPath, configuredCustomPath)) - return - } - ctx := r.Context() ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) From dcbcd88a62cfa452da2fb1d2a9b049cdc735159d Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 16 Mar 2022 00:04:15 +0100 Subject: [PATCH 28/44] Add changelog item for dynamic SCEP CA URL paths --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44c713d7..b43a5f7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added ### Changed +- Made SCEP CA URL paths dynamic ### Deprecated ### Removed ### Fixed From 7fb8acda2778c9188e6d19583ce041f4706629fe Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 15:21:40 +0200 Subject: [PATCH 29/44] api/read: initial implementation of the package --- api/read/read.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 api/read/read.go diff --git a/api/read/read.go b/api/read/read.go new file mode 100644 index 00000000..fab8fa8f --- /dev/null +++ b/api/read/read.go @@ -0,0 +1,30 @@ +// Package read implements request object readers. +package read + +import ( + "encoding/json" + "io" + + "github.com/smallstep/certificates/errs" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// JSON reads JSON from the request body and stores it in the value +// pointed by v. +func JSON(r io.Reader, v interface{}) error { + if err := json.NewDecoder(r).Decode(v); err != nil { + return errs.BadRequestErr(err, "error decoding json") + } + return nil +} + +// ProtoJSON reads JSON from the request body and stores it in the value +// pointed by v. +func ProtoJSON(r io.Reader, m proto.Message) error { + data, err := io.ReadAll(r) + if err != nil { + return errs.BadRequestErr(err, "error reading request body") + } + return protojson.Unmarshal(data, m) +} From 29092b9d8aecfd21238d7967a202f51bbd165228 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 15:25:34 +0200 Subject: [PATCH 30/44] api: refactored to use the read package --- api/api.go | 1 + api/api_test.go | 8 +++++--- api/errors.go | 1 + api/read/read.go | 3 ++- api/rekey.go | 3 ++- api/revoke.go | 6 ++++-- api/revoke_test.go | 1 + api/sign.go | 3 ++- api/ssh.go | 12 +++++++----- api/sshRekey.go | 6 ++++-- api/sshRenew.go | 4 +++- api/sshRevoke.go | 6 ++++-- api/ssh_test.go | 3 ++- api/utils.go | 24 ++---------------------- api/utils_test.go | 4 +++- 15 files changed, 43 insertions(+), 42 deletions(-) diff --git a/api/api.go b/api/api.go index 912e39dd..47d2fd27 100644 --- a/api/api.go +++ b/api/api.go @@ -20,6 +20,7 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" diff --git a/api/api_test.go b/api/api_test.go index 717621cd..25abdeff 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -28,15 +28,17 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/x509util" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" ) const ( diff --git a/api/errors.go b/api/errors.go index bff46b55..522fa955 100644 --- a/api/errors.go +++ b/api/errors.go @@ -7,6 +7,7 @@ import ( "os" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/errs" diff --git a/api/read/read.go b/api/read/read.go index fab8fa8f..de92c5d7 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -5,9 +5,10 @@ import ( "encoding/json" "io" - "github.com/smallstep/certificates/errs" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/errs" ) // JSON reads JSON from the request body and stores it in the value diff --git a/api/rekey.go b/api/rekey.go index b7958844..269086bb 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -3,6 +3,7 @@ package api import ( "net/http" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/errs" ) @@ -32,7 +33,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } var body RekeyRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/revoke.go b/api/revoke.go index 25520e3e..49822e6d 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -4,11 +4,13 @@ import ( "context" "net/http" + "golang.org/x/crypto/ocsp" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "golang.org/x/crypto/ocsp" ) // RevokeResponse is the response object that returns the health of the server. @@ -48,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) { // TODO: Add CRL and OCSP support. func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..7635ce68 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/pkg/errors" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" diff --git a/api/sign.go b/api/sign.go index 93c5f599..b2eef45d 100644 --- a/api/sign.go +++ b/api/sign.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -49,7 +50,7 @@ type SignResponse struct { // information in the certificate request. func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh.go b/api/ssh.go index c9be1527..fc185d07 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -9,12 +9,14 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/templates" - "golang.org/x/crypto/ssh" ) // SSHAuthority is the interface implemented by a SSH CA authority. @@ -249,7 +251,7 @@ type SSHBastionResponse struct { // the request. func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -393,7 +395,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // and servers. func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -425,7 +427,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -464,7 +466,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { // SSHBastion provides returns the bastion configured if any. func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index 8670f0bd..b7581749 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -4,9 +4,11 @@ import ( "net/http" "time" + "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "golang.org/x/crypto/ssh" ) // SSHRekeyRequest is the request body of an SSH certificate request. @@ -38,7 +40,7 @@ type SSHRekeyResponse struct { // the request. func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index 57b6f432..b98466bf 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -6,6 +6,8 @@ import ( "time" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -36,7 +38,7 @@ type SSHRenewResponse struct { // the request. func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 60f44f2a..2d2da1f7 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -3,11 +3,13 @@ package api import ( "net/http" + "golang.org/x/crypto/ocsp" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "golang.org/x/crypto/ocsp" ) // SSHRevokeResponse is the response object that returns the health of the server. @@ -47,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest - if err := ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh_test.go b/api/ssh_test.go index a3d7da0d..88a301f5 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -18,12 +18,13 @@ import ( "testing" "time" + "golang.org/x/crypto/ssh" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" - "golang.org/x/crypto/ssh" ) var ( diff --git a/api/utils.go b/api/utils.go index a7f4bf58..9daa0cd2 100644 --- a/api/utils.go +++ b/api/utils.go @@ -2,14 +2,13 @@ package api import ( "encoding/json" - "io" "log" "net/http" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/logging" ) // EnableLogger is an interface that enables response logging for an object. @@ -88,22 +87,3 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { } //LogEnabledResponse(w, v) } - -// ReadJSON reads JSON from the request body and stores it in the value -// pointed by v. -func ReadJSON(r io.Reader, v interface{}) error { - if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.BadRequestErr(err, "error decoding json") - } - return nil -} - -// ReadProtoJSON reads JSON from the request body and stores it in the value -// pointed by v. -func ReadProtoJSON(r io.Reader, m proto.Message) error { - data, err := io.ReadAll(r) - if err != nil { - return errs.BadRequestErr(err, "error reading request body") - } - return protojson.Unmarshal(data, m) -} diff --git a/api/utils_test.go b/api/utils_test.go index 81146653..c683b4c2 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) @@ -104,7 +106,7 @@ func TestReadJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ReadJSON(tt.args.r, &tt.args.v) + err := read.JSON(tt.args.r, &tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr) } From 4fb38afc573813343551f7c86ddafd973f637b3c Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 15:28:42 +0200 Subject: [PATCH 31/44] authority/admin/api: refactored to use the read package --- authority/admin/api/admin.go | 9 ++++++--- authority/admin/api/provisioner.go | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 7aa66d0f..43607c52 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -5,10 +5,13 @@ import ( "net/http" "github.com/go-chi/chi" + + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) type adminAuthority interface { @@ -112,7 +115,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { // CreateAdmin creates a new admin. func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest - if err := api.ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } @@ -156,7 +159,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { // UpdateAdmin updates an existing admin. func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest - if err := api.ReadJSON(r.Body, &body); err != nil { + if err := read.JSON(r.Body, &body); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index b8cc0f4c..2106733d 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -4,12 +4,14 @@ import ( "net/http" "github.com/go-chi/chi" + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "go.step.sm/linkedca" ) // GetProvisionersResponse is the type for GET /admin/provisioners responses. @@ -72,7 +74,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { // CreateProvisioner creates a new prov. func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) - if err := api.ReadProtoJSON(r.Body, prov); err != nil { + if err := read.ProtoJSON(r.Body, prov); err != nil { api.WriteError(w, err) return } @@ -122,7 +124,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { // UpdateProvisioner updates an existing prov. func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) - if err := api.ReadProtoJSON(r.Body, nu); err != nil { + if err := read.ProtoJSON(r.Body, nu); err != nil { api.WriteError(w, err) return } From 9ba33bab4e302e88f051072a17befe315cb0a4b2 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 15:29:20 +0200 Subject: [PATCH 32/44] ca: refactored to use the read package --- ca/client_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ca/client_test.go b/ca/client_test.go index a00ca1cf..794d0e35 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -17,13 +17,15 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" + "github.com/smallstep/assert" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" ) const ( @@ -354,7 +356,7 @@ func TestClient_Sign(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) - if err := api.ReadJSON(req.Body, body); err != nil { + if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") api.WriteError(w, e) @@ -426,7 +428,7 @@ func TestClient_Revoke(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) - if err := api.ReadJSON(req.Body, body); err != nil { + if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") api.WriteError(w, e) From df89ed5acb0c4702bc77d10bf0e475624797e7e4 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 16:58:25 +0200 Subject: [PATCH 33/44] api: moved read-related tests to api/read --- api/read/read_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++ api/utils_test.go | 39 ------------------------------------ 2 files changed, 46 insertions(+), 39 deletions(-) create mode 100644 api/read/read_test.go diff --git a/api/read/read_test.go b/api/read/read_test.go new file mode 100644 index 00000000..f2eff1bc --- /dev/null +++ b/api/read/read_test.go @@ -0,0 +1,46 @@ +package read + +import ( + "io" + "reflect" + "strings" + "testing" + + "github.com/smallstep/certificates/errs" +) + +func TestJSON(t *testing.T) { + type args struct { + r io.Reader + v interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, + {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := JSON(tt.args.r, &tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + e, ok := err.(*errs.Error) + if ok { + if code := e.StatusCode(); code != 400 { + t.Errorf("error.StatusCode() = %v, wants 400", code) + } + } else { + t.Errorf("error type = %T, wants *Error", err) + } + } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { + t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) + } + }) + } +} diff --git a/api/utils_test.go b/api/utils_test.go index c683b4c2..12350c97 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -1,17 +1,13 @@ package api import ( - "io" "net/http" "net/http/httptest" "reflect" - "strings" "testing" "github.com/pkg/errors" - "github.com/smallstep/certificates/api/read" - "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) @@ -90,38 +86,3 @@ func TestJSON(t *testing.T) { }) } } - -func TestReadJSON(t *testing.T) { - type args struct { - r io.Reader - v interface{} - } - tests := []struct { - name string - args args - wantErr bool - }{ - {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, - {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := read.JSON(tt.args.r, &tt.args.v) - if (err != nil) != tt.wantErr { - t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - e, ok := err.(*errs.Error) - if ok { - if code := e.StatusCode(); code != 400 { - t.Errorf("error.StatusCode() = %v, wants 400", code) - } - } else { - t.Errorf("error type = %T, wants *Error", err) - } - } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { - t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) - } - }) - } -} From e6b235927314e952e3841c9d655aacb6ce868c16 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 18:48:43 +0200 Subject: [PATCH 34/44] ca: fixed import statement order --- ca/client_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ca/client_test.go b/ca/client_test.go index 794d0e35..4628d19b 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -17,9 +17,10 @@ import ( "time" "github.com/pkg/errors" - "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" + "go.step.sm/crypto/x509util" + "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" From 101ca6a2d379a67ac8a4347b0fa5fe74f021ca87 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 21 Mar 2022 15:53:59 +0100 Subject: [PATCH 35/44] Check admin subjects before changing policy --- acme/api/order.go | 2 + authority/admin/api/acme.go | 2 +- authority/admin/api/admin.go | 4 +- authority/admin/api/admin_test.go | 4 +- authority/admin/api/handler.go | 3 +- authority/admin/api/middleware.go | 26 ++++--- authority/admin/api/middleware_test.go | 10 +-- authority/admin/api/policy.go | 48 +++++++++--- authority/admin/context.go | 10 +++ authority/administrator/collection.go | 2 +- authority/authority.go | 10 ++- authority/policy.go | 102 ++++++++++++++++++++----- authority/tls.go | 89 ++++++++++++++++----- policy/engine.go | 31 ++++---- policy/engine_test.go | 2 +- policy/options_test.go | 1 + 16 files changed, 255 insertions(+), 91 deletions(-) create mode 100644 authority/admin/context.go diff --git a/acme/api/order.go b/acme/api/order.go index e1adebb3..8fe37656 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -105,6 +105,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { // management of allowed/denied names based on just the name, without having bound to EAB. Still, // EAB is not illogical, because that's the way Accounts are connected to an external system and // thus make sense to also set the allowed/denied names based on that info. + // TODO: also perform check on the authority level here already, so that challenges are not performed + // and after that the CA fails to sign it. (i.e. h.ca function?) for _, identifier := range nor.Identifiers { // TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example? diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 27c3ba6f..131a8fff 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -14,7 +14,7 @@ import ( const ( // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") + provisionerContextKey = admin.ContextKey("provisioner") ) // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index dd40784b..34db5ea2 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -26,8 +26,8 @@ type adminAuthority interface { UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) - StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error - UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error + UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error RemoveAuthorityPolicy(ctx context.Context) error } diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index f1698139..bcea31b5 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -139,11 +139,11 @@ func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca. return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { +func (m *mockAdminAuthority) StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { return errors.New("not implemented yet") } -func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { return errors.New("not implemented yet") } diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index e59b95e0..0dd45cb0 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -30,8 +30,7 @@ func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeRespo func (h *Handler) Route(r api.Router) { authnz := func(next nextHTTP) nextHTTP { - //return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) - return h.requireAPIEnabled(next) // TODO(hs): remove this; temporarily no auth checks for simple testing... + return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) } requireEABEnabled := func(next nextHTTP) nextHTTP { diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 62aefdc3..c30c7219 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -7,6 +7,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin/db/nosql" + "go.step.sm/linkedca" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -27,6 +28,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { + tok := r.Header.Get("Authorization") if tok == "" { api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, @@ -40,7 +42,7 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - ctx := context.WithValue(r.Context(), adminContextKey, adm) + ctx := context.WithValue(r.Context(), admin.AdminContextKey, adm) next(w, r.WithContext(ctx)) } } @@ -49,13 +51,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - // actions allowed in standalone mode are always allowed + // actions allowed in standalone mode are always supported if supportedInStandalone { next(w, r) return } - // when in standalone mode, actions are not supported + // when not in standalone mode and using a nosql.DB backend, + // actions are not supported if _, ok := h.adminDB.(*nosql.DB); ok { api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")) @@ -67,11 +70,12 @@ func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTT } } -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string - -const ( - // adminContextKey account key - adminContextKey = ContextKey("admin") -) +// adminFromContext searches the context for a *linkedca.Admin. +// Returns the admin or an error. +func adminFromContext(ctx context.Context) (*linkedca.Admin, error) { + val, ok := ctx.Value(admin.AdminContextKey).(*linkedca.Admin) + if !ok || val == nil { + return nil, admin.NewError(admin.ErrorBadRequestType, "admin not in context") + } + return val, nil +} diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 7fb4671a..ffa319db 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -152,7 +152,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { req.Header["Authorization"] = []string{"token"} createdAt := time.Now() var deletedAt time.Time - admin := &linkedca.Admin{ + adm := &linkedca.Admin{ Id: "adminID", AuthorityId: "authorityID", Subject: "admin", @@ -164,20 +164,20 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) - return admin, nil + return adm, nil }, } next := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin + a := ctx.Value(admin.AdminContextKey) // verifying that the context now has a linkedca.Admin adm, ok := a.(*linkedca.Admin) if !ok { t.Errorf("expected *linkedca.Admin; got %T", a) return } opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} - if !cmp.Equal(admin, adm, opts...) { - t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) + if !cmp.Equal(adm, adm, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...)) } w.Write(nil) // mock response with status 200 } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index c318e5e5..2f64802f 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -87,8 +87,14 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r return } - if err := par.auth.StoreAuthorityPolicy(ctx, newPolicy); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing authority policy")) + adm, err := adminFromContext(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) + return + } + + if err := par.auth.StoreAuthorityPolicy(ctx, adm, newPolicy); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return } @@ -103,25 +109,49 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - var policy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, policy); err != nil { + + ctx := r.Context() + policy, err := par.auth.GetAuthorityPolicy(ctx) + + shouldWriteError := false + if ae, ok := err.(*admin.Error); ok { + shouldWriteError = !ae.IsType(admin.ErrorNotFoundType) + } + + if shouldWriteError { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if policy == nil { + api.JSONNotFound(w) + return + } + + var newPolicy = new(linkedca.Policy) + if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { api.WriteError(w, err) return } - ctx := r.Context() - if err := par.auth.UpdateAuthorityPolicy(ctx, policy); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error updating authority policy")) + adm, err := adminFromContext(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) return } - newPolicy, err := par.auth.GetAuthorityPolicy(ctx) + if err := par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) + return + } + + newlyStoredPolicy, err := par.auth.GetAuthorityPolicy(ctx) if err != nil { api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) return } - api.ProtoJSONStatus(w, newPolicy, http.StatusOK) + api.ProtoJSONStatus(w, newlyStoredPolicy, http.StatusOK) } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request diff --git a/authority/admin/context.go b/authority/admin/context.go new file mode 100644 index 00000000..87bf3e03 --- /dev/null +++ b/authority/admin/context.go @@ -0,0 +1,10 @@ +package admin + +// ContextKey is the key type for storing and searching for +// Admin API objects in request contexts. +type ContextKey string + +const ( + // AdminContextKey account key + AdminContextKey = ContextKey("admin") +) diff --git a/authority/administrator/collection.go b/authority/administrator/collection.go index 88d7bb2c..300c3e4f 100644 --- a/authority/administrator/collection.go +++ b/authority/administrator/collection.go @@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool } // Store adds an admin to the collection and enforces the uniqueness of -// admin IDs and amdin subject <-> provisioner name combos. +// admin IDs and admin subject <-> provisioner name combos. func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { // Input validation. if adm.ProvisionerId != prov.GetID() { diff --git a/authority/authority.go b/authority/authority.go index aaf0e478..29a10d7e 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -219,15 +219,19 @@ func (a *Authority) reloadPolicyEngines(ctx context.Context) error { if a.config.AuthorityConfig.EnableAdmin { linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) if err != nil { - return admin.WrapErrorISE(err, "error getting policy to initialize authority") + return admin.WrapErrorISE(err, "error getting policy to (re)load policy engines") } policyOptions = policyToCertificates(linkedPolicy) } else { policyOptions = a.config.AuthorityConfig.Policy } - // return early if no policy options set + // if no new or updated policy option is set, clear policy engines that (may have) + // been configured before and return early if policyOptions == nil { + a.x509Policy = nil + a.sshHostPolicy = nil + a.sshUserPolicy = nil return nil } @@ -574,7 +578,7 @@ func (a *Authority) init() error { return err } - // Load Policy Engines + // Load x509 and SSH Policy Engines if err := a.reloadPolicyEngines(context.Background()); err != nil { return err } diff --git a/authority/policy.go b/authority/policy.go index 8ef264d0..db44e5f4 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -2,10 +2,15 @@ package authority import ( "context" + "fmt" + + "github.com/pkg/errors" + + "go.step.sm/linkedca" "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/policy" - "go.step.sm/linkedca" + authPolicy "github.com/smallstep/certificates/authority/policy" + policy "github.com/smallstep/certificates/policy" ) func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { @@ -20,31 +25,39 @@ func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, e return policy, nil } -func (a *Authority) StoreAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { +func (a *Authority) StoreAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() + if err := a.checkPolicy(ctx, adm, policy); err != nil { + return err + } + if err := a.adminDB.CreateAuthorityPolicy(ctx, policy); err != nil { return err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading admin resources when creating authority policy") + return admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") } return nil } -func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() + if err := a.checkPolicy(ctx, adm, policy); err != nil { + return err + } + if err := a.adminDB.UpdateAuthorityPolicy(ctx, policy); err != nil { return err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading admin resources when updating authority policy") + return admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") } return nil @@ -59,34 +72,84 @@ func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading admin resources when deleting authority policy") + return admin.WrapErrorISE(err, "error reloading policy engines when deleting authority policy") } return nil } -func policyToCertificates(p *linkedca.Policy) *policy.Options { +func (a *Authority) checkPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) error { + + // convert the policy; return early if nil + policyOptions := policyToCertificates(p) + if policyOptions == nil { + return nil + } + + engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options()) + if err != nil { + return admin.WrapErrorISE(err, "error creating temporary policy engine") + } + + // TODO(hs): Provide option to force the policy, even when the admin subject would be locked out? + + sans := []string{adm.Subject} + if err := isAllowed(engine, sans); err != nil { + return err + } + + // TODO(hs): perform the check for other admin subjects too? + // What logic to use for that: do all admins need access? Only super admins? At least one? + + return nil +} + +func isAllowed(engine authPolicy.X509Policy, sans []string) error { + var ( + allowed bool + err error + ) + if allowed, err = engine.AreSANsAllowed(sans); err != nil { + var policyErr *policy.NamePolicyError + if errors.As(err, &policyErr); policyErr.Reason == policy.NotAuthorizedForThisName { + return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans) + } else { + return err + } + } + + if !allowed { + return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans) + } + + return nil +} + +func policyToCertificates(p *linkedca.Policy) *authPolicy.Options { + // return early if p == nil { return nil } + // prepare full policy struct - opts := &policy.Options{ - X509: &policy.X509PolicyOptions{ - AllowedNames: &policy.X509NameOptions{}, - DeniedNames: &policy.X509NameOptions{}, + opts := &authPolicy.Options{ + X509: &authPolicy.X509PolicyOptions{ + AllowedNames: &authPolicy.X509NameOptions{}, + DeniedNames: &authPolicy.X509NameOptions{}, }, - SSH: &policy.SSHPolicyOptions{ - Host: &policy.SSHHostCertificateOptions{ - AllowedNames: &policy.SSHNameOptions{}, - DeniedNames: &policy.SSHNameOptions{}, + SSH: &authPolicy.SSHPolicyOptions{ + Host: &authPolicy.SSHHostCertificateOptions{ + AllowedNames: &authPolicy.SSHNameOptions{}, + DeniedNames: &authPolicy.SSHNameOptions{}, }, - User: &policy.SSHUserCertificateOptions{ - AllowedNames: &policy.SSHNameOptions{}, - DeniedNames: &policy.SSHNameOptions{}, + User: &authPolicy.SSHUserCertificateOptions{ + AllowedNames: &authPolicy.SSHNameOptions{}, + DeniedNames: &authPolicy.SSHNameOptions{}, }, }, } + // fill x509 policy configuration if p.X509 != nil { if p.X509.Allow != nil { @@ -102,6 +165,7 @@ func policyToCertificates(p *linkedca.Policy) *policy.Options { opts.X509.DeniedNames.URIDomains = p.X509.Deny.Uris } } + // fill ssh policy configuration if p.Ssh != nil { if p.Ssh.Host != nil { diff --git a/authority/tls.go b/authority/tls.go index 96c80e9a..297c796e 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -191,26 +191,20 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } - // If a policy is configured, perform allow/deny policy check on authority level - // TODO: policy currently also applies to admin token certs; how to circumvent? - // Allow any name of an admin in the DB? Or in the admin collection? - todoRemoveThis := false - if todoRemoveThis && a.x509Policy != nil { - allowed, err := a.x509Policy.AreCertificateNamesAllowed(leaf) - if err != nil { - return nil, errs.InternalServerErr(err, - errs.WithKeyVal("csr", csr), - errs.WithKeyVal("signOptions", signOpts), - errs.WithMessage("error creating certificate"), - ) - } - if !allowed { - // TODO: include SANs in error message? - return nil, errs.ApplyOptions( - errs.ForbiddenErr(errors.New("authority not allowed to sign"), "error creating certificate"), - opts..., - ) - } + // Check if authority is allowed to sign the certificate + var allowedToSign bool + if allowedToSign, err = a.isAllowedToSign(leaf); err != nil { + return nil, errs.InternalServerErr(err, + errs.WithKeyVal("csr", csr), + errs.WithKeyVal("signOptions", signOpts), + errs.WithMessage("error creating certificate"), + ) + } + if !allowedToSign { + return nil, errs.ApplyOptions( + errs.ForbiddenErr(errors.New("authority not allowed to sign"), "error creating certificate"), + opts..., + ) } // Sign certificate @@ -236,6 +230,61 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign return fullchain, nil } +// isAllowedToSign checks if the Authority is allowed to sign the X.509 certificate. +// It first checks if the certificate contains an admin subject that exists in the +// collection of admins. The CA is always allowed to sign those. If the cert contains +// different names and a policy is configured, the policy will be executed against +// the cert to see if the CA is allowed to sign it. +func (a *Authority) isAllowedToSign(cert *x509.Certificate) (bool, error) { + + // // check if certificate is an admin identity token certificate and the admin subject exists + // b := isAdminIdentityTokenCertificate(cert) + // _ = b + + // if isAdminIdentityTokenCertificate(cert) && a.admins.HasSubject(cert.Subject.CommonName) { + // return true, nil + // } + + // if no policy is configured, the cert is implicitly allowed + if a.x509Policy == nil { + return true, nil + } + + return a.x509Policy.AreCertificateNamesAllowed(cert) +} + +func isAdminIdentityTokenCertificate(cert *x509.Certificate) bool { + + // TODO: remove this check + + if cert.Issuer.CommonName != "" { + return false + } + + subject := cert.Subject.CommonName + if subject == "" { + return false + } + + dnsNames := cert.DNSNames + if len(dnsNames) != 1 { + return false + } + + if dnsNames[0] != subject { + return false + } + + extras := cert.ExtraExtensions + if len(extras) != 1 { + return false + } + + extra := extras[0] + + return extra.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) +} + // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { diff --git a/policy/engine.go b/policy/engine.go index e9038dd0..63d8452a 100755 --- a/policy/engine.go +++ b/policy/engine.go @@ -10,9 +10,10 @@ import ( "reflect" "strings" - "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" "golang.org/x/net/idna" + + "go.step.sm/crypto/x509util" ) type NamePolicyReason int @@ -39,7 +40,7 @@ type NamePolicyError struct { Detail string } -func (e NamePolicyError) Error() string { +func (e *NamePolicyError) Error() string { switch e.Reason { case NotAuthorizedForThisName: return "not authorized to sign for this name: " + e.Detail @@ -295,7 +296,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA // then return error, because DNS should be explicitly configured to be allowed in that case. In case there are // (other) excluded constraints, we'll allow a DNS (implicit allow; currently). if e.numberOfDNSDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("dns %q is not explicitly permitted by any constraint", dns), } @@ -307,7 +308,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA } parsedDNS, err := idna.Lookup.ToASCII(dns) if err != nil { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotParseDomain, Detail: fmt.Sprintf("dns %q cannot be converted to ASCII", dns), } @@ -316,7 +317,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA parsedDNS = "*" + parsedDNS } if _, ok := domainToReverseLabels(parsedDNS); !ok { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotParseDomain, Detail: fmt.Sprintf("cannot parse dns %q", dns), } @@ -331,7 +332,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA for _, ip := range ips { if e.numberOfIPRangeConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("ip %q is not explicitly permitted by any constraint", ip.String()), } @@ -346,14 +347,14 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA for _, email := range emailAddresses { if e.numberOfEmailAddressConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("email %q is not explicitly permitted by any constraint", email), } } mailbox, ok := parseRFC2821Mailbox(email) if !ok { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotParseRFC822Name, Detail: fmt.Sprintf("invalid rfc822Name %q", mailbox), } @@ -363,7 +364,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA // https://datatracker.ietf.org/doc/html/rfc5280#section-7.5 domainASCII, err := idna.ToASCII(mailbox.domain) if err != nil { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotParseDomain, Detail: fmt.Sprintf("cannot parse email domain %q", email), } @@ -381,7 +382,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA for _, uri := range uris { if e.numberOfURIDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("uri %q is not explicitly permitted by any constraint", uri.String()), } @@ -396,7 +397,7 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA for _, principal := range principals { if e.numberOfPrincipalConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("username principal %q is not explicitly permitted by any constraint", principal), } @@ -431,14 +432,14 @@ func checkNameConstraints( constraint := excludedValue.Index(i).Interface() match, err := match(parsedName, constraint) if err != nil { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotMatchNameToConstraint, Detail: err.Error(), } } if match { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint), } @@ -452,7 +453,7 @@ func checkNameConstraints( constraint := permittedValue.Index(i).Interface() var err error if ok, err = match(parsedName, constraint); err != nil { - return NamePolicyError{ + return &NamePolicyError{ Reason: CannotMatchNameToConstraint, Detail: err.Error(), } @@ -464,7 +465,7 @@ func checkNameConstraints( } if !ok { - return NamePolicyError{ + return &NamePolicyError{ Reason: NotAuthorizedForThisName, Detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name), } diff --git a/policy/engine_test.go b/policy/engine_test.go index 0259e8de..f7a4b20a 100755 --- a/policy/engine_test.go +++ b/policy/engine_test.go @@ -13,7 +13,7 @@ import ( ) // TODO(hs): the functionality in the policy engine is a nice candidate for trying fuzzing on -// TODO(hs): more complex uses cases that combine multiple names and permitted/excluded entries +// TODO(hs): more complex use cases that combine multiple names and permitted/excluded entries func TestNamePolicyEngine_matchDomainConstraint(t *testing.T) { tests := []struct { diff --git a/policy/options_test.go b/policy/options_test.go index 0fc54aa2..8a64f282 100644 --- a/policy/options_test.go +++ b/policy/options_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/smallstep/assert" ) From 390054b22e4cb49f01187e89339f04c7b2564d19 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 21 Mar 2022 16:22:26 -0700 Subject: [PATCH 36/44] Change go version to 1.17 and 1.18 --- .github/workflows/release.yml | 8 ++++---- .github/workflows/test.yml | 6 +++--- CHANGELOG.md | 1 + 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5d0416ef..2ab7084d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.15', '1.16', '1.17' ] + go: [ '1.17', '1.18' ] outputs: is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} steps: @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.44.0' + version: 'v1.45.0' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -106,7 +106,7 @@ jobs: name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: APT Install id: aptInstall @@ -159,7 +159,7 @@ jobs: name: Setup Go uses: actions/setup-go@v2 with: - go-version: '1.17' + go-version: '1.18' - name: Install cosign uses: sigstore/cosign-installer@v1.1.0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f36e78ef..64cb64cd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.16', '1.17' ] + go: [ '1.17', '1.18' ] steps: - name: Checkout @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.44.0' + version: 'v1.45.0' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -58,7 +58,7 @@ jobs: run: V=1 make ci - name: Codecov - if: matrix.go == '1.17' + if: matrix.go == '1.18' uses: codecov/codecov-action@v1.2.1 with: file: ./coverage.out # optional diff --git a/CHANGELOG.md b/CHANGELOG.md index fc25c0ed..3164b3b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Added support for renew after expiry using the claim `allowRenewAfterExpiry`. ### Changed - Made SCEP CA URL paths dynamic +- Support two latest versions of golang (1.17, 1.18) ### Deprecated ### Removed ### Fixed From ad8a813abe89fc019bbb3242a3cbc48f110ccfd1 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 21 Mar 2022 16:53:57 -0700 Subject: [PATCH 37/44] Fix linter errors --- authority/provisioner/x5c.go | 4 +++- authority/provisioner/x5c_test.go | 2 ++ ca/ca.go | 3 --- ca/identity/client_test.go | 23 ++++++++++++++++++++++- ca/identity/identity_test.go | 2 ++ ca/tls.go | 2 -- ca/tls_options_test.go | 1 + 7 files changed, 30 insertions(+), 7 deletions(-) diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 6f534c76..51b5d8fd 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -100,6 +100,7 @@ func (p *X5C) Init(config Config) (err error) { var ( block *pem.Block rest = p.Roots + count int ) for rest != nil { block, rest = pem.Decode(rest) @@ -110,11 +111,12 @@ func (p *X5C) Init(config Config) (err error) { if err != nil { return errors.Wrap(err, "error parsing x509 certificate from PEM block") } + count++ p.rootPool.AddCert(cert) } // Verify that at least one root was found. - if len(p.rootPool.Subjects()) == 0 { + if count == 0 { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 84e29b48..7932d045 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -118,6 +118,8 @@ M46l92gdOozT return ProvisionerValidateTest{ p: p, extraValid: func(p *X5C) error { + // nolint:staticcheck // We don't have a different way to + // check the number of certificates in the pool. numCerts := len(p.rootPool.Subjects()) if numCerts != 2 { return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) diff --git a/ca/ca.go b/ca/ca.go index c95ba22f..dfb82731 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -450,9 +450,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientCAs = certPool - // Use server's most preferred ciphersuite - tlsConfig.PreferServerCipherSuites = true - return tlsConfig, nil } diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 0f1234e9..9660a3bd 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "reflect" + "sort" "testing" ) @@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) { switch { case gotTransport.TLSClientConfig.GetClientCertificate == nil: t.Error("LoadClient() transport does not define GetClientCertificate") - case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): + case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs): t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) default: crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) @@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) { }) } } + +// nolint:staticcheck,gocritic +func equalPools(a, b *x509.CertPool) bool { + if reflect.DeepEqual(a, b) { + return true + } + subjects := a.Subjects() + sA := make([]string, len(subjects)) + for i := range subjects { + sA[i] = string(subjects[i]) + } + subjects = b.Subjects() + sB := make([]string, len(subjects)) + for i := range subjects { + sB[i] = string(subjects[i]) + } + sort.Strings(sA) + sort.Strings(sB) + return reflect.DeepEqual(sA, sB) +} diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index d3b1d541..55fc60fd 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) { return } if got != nil { + // nolint:staticcheck // we don't have a different way to check + // the certificates in the pool. subjects := got.Subjects() if !reflect.DeepEqual(subjects, tt.wantSubjects) { t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) diff --git a/ca/tls.go b/ca/tls.go index 0738d0e0..7954cbdf 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -95,7 +95,6 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Note that with GetClientCertificate tlsConfig.Certificates is not used. // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetClientCertificate = renewer.GetClientCertificate - tlsConfig.PreferServerCipherSuites = true // Apply options and initialize mutable tls.Config tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) @@ -137,7 +136,6 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate - tlsConfig.PreferServerCipherSuites = true tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Apply options and initialize mutable tls.Config diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 7d94926b..ca5f80b8 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -542,6 +542,7 @@ func TestAddFederationToCAs(t *testing.T) { } } +// nolint:staticcheck,gocritic func equalPools(a, b *x509.CertPool) bool { if reflect.DeepEqual(a, b) { return true From 24a963766e23e3f1f55165392b63ee8bcbc44987 Mon Sep 17 00:00:00 2001 From: vijayjt <2975049+vijayjt@users.noreply.github.com> Date: Tue, 22 Mar 2022 00:10:43 +0000 Subject: [PATCH 38/44] Pass in the resource name regardless of if its a VM or managed identity --- authority/provisioner/azure.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 391034bc..0afd396f 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -269,13 +269,8 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str var subscription, group, name string identityObjectID := claims.ObjectID + subscription, group, name = re[1], re[2], re[4] - if strings.Contains(claims.XMSMirID, "virtualMachines") { - subscription, group, name = re[1], re[2], re[4] - } else { - // This is not a VM resource ID so we don't have the VM name so set that to the empty string - subscription, group, name = re[1], re[2], "" - } return &claims, name, group, subscription, identityObjectID, nil } From f1d586bc6d3be32b536bc6d16d068d667ef24482 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 21 Mar 2022 17:59:15 -0700 Subject: [PATCH 39/44] Change golang to Go --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3164b3b6..73c338f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Added support for renew after expiry using the claim `allowRenewAfterExpiry`. ### Changed - Made SCEP CA URL paths dynamic -- Support two latest versions of golang (1.17, 1.18) +- Support two latest versions of Go (1.17, 1.18) ### Deprecated ### Removed ### Fixed From 80abda22eed7270f5d123fd93eb0e545373b165c Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 14:31:18 +0200 Subject: [PATCH 40/44] api/log: initial implementation of the package (#859) * api/log: initial implementation of the package * api: refactored to support api/log * scep/api: refactored to support api/log * api/log: documented the package * api: moved log-related tests to api/log --- api/errors.go | 3 ++- api/log/log.go | 47 +++++++++++++++++++++++++++++++++++++ api/log/log_test.go | 44 +++++++++++++++++++++++++++++++++++ api/utils.go | 56 ++++++++++----------------------------------- api/utils_test.go | 35 ---------------------------- scep/api/api.go | 14 ++++++------ 6 files changed, 112 insertions(+), 87 deletions(-) create mode 100644 api/log/log.go create mode 100644 api/log/log_test.go diff --git a/api/errors.go b/api/errors.go index 522fa955..49efd486 100644 --- a/api/errors.go +++ b/api/errors.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/log" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" @@ -60,6 +61,6 @@ func WriteError(w http.ResponseWriter, err error) { } if err := json.NewEncoder(w).Encode(err); err != nil { - LogError(w, err) + log.Error(w, err) } } diff --git a/api/log/log.go b/api/log/log.go new file mode 100644 index 00000000..78dae506 --- /dev/null +++ b/api/log/log.go @@ -0,0 +1,47 @@ +// Package log implements API-related logging helpers. +package log + +import ( + "log" + "net/http" + + "github.com/smallstep/certificates/logging" +) + +// Error adds to the response writer the given error if it implements +// logging.ResponseLogger. If it does not implement it, then writes the error +// using the log package. +func Error(rw http.ResponseWriter, err error) { + if rl, ok := rw.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "error": err, + }) + } else { + log.Println(err) + } +} + +// EnabledResponse log the response object if it implements the EnableLogger +// interface. +func EnabledResponse(rw http.ResponseWriter, v interface{}) { + type enableLogger interface { + ToLog() (interface{}, error) + } + + if el, ok := v.(enableLogger); ok { + out, err := el.ToLog() + if err != nil { + Error(rw, err) + + return + } + + if rl, ok := rw.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "response": out, + }) + } else { + log.Println(out) + } + } +} diff --git a/api/log/log_test.go b/api/log/log_test.go new file mode 100644 index 00000000..fcd3ea2b --- /dev/null +++ b/api/log/log_test.go @@ -0,0 +1,44 @@ +package log + +import ( + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/smallstep/certificates/logging" +) + +func TestError(t *testing.T) { + theError := errors.New("the error") + + type args struct { + rw http.ResponseWriter + err error + } + tests := []struct { + name string + args args + withFields bool + }{ + {"normalLogger", args{httptest.NewRecorder(), theError}, false}, + {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Error(tt.args.rw, tt.args.err) + if tt.withFields { + if rl, ok := tt.args.rw.(logging.ResponseLogger); ok { + fields := rl.Fields() + if !reflect.DeepEqual(fields["error"], theError) { + t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError) + } + } else { + t.Error("ResponseWriter does not implement logging.ResponseLogger") + } + } + }) + } +} diff --git a/api/utils.go b/api/utils.go index 9daa0cd2..e3fcc9c4 100644 --- a/api/utils.go +++ b/api/utils.go @@ -2,52 +2,14 @@ package api import ( "encoding/json" - "log" "net/http" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/log" ) -// EnableLogger is an interface that enables response logging for an object. -type EnableLogger interface { - ToLog() (interface{}, error) -} - -// LogError adds to the response writer the given error if it implements -// logging.ResponseLogger. If it does not implement it, then writes the error -// using the log package. -func LogError(rw http.ResponseWriter, err error) { - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - } else { - log.Println(err) - } -} - -// LogEnabledResponse log the response object if it implements the EnableLogger -// interface. -func LogEnabledResponse(rw http.ResponseWriter, v interface{}) { - if el, ok := v.(EnableLogger); ok { - out, err := el.ToLog() - if err != nil { - LogError(rw, err) - return - } - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "response": out, - }) - } else { - log.Println(out) - } - } -} - // JSON writes the passed value into the http.ResponseWriter. func JSON(w http.ResponseWriter, v interface{}) { JSONStatus(w, v, http.StatusOK) @@ -59,10 +21,12 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(v); err != nil { - LogError(w, err) + log.Error(w, err) + return } - LogEnabledResponse(w, v) + + log.EnabledResponse(w, v) } // ProtoJSON writes the passed value into the http.ResponseWriter. @@ -78,12 +42,16 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { b, err := protojson.Marshal(m) if err != nil { - LogError(w, err) + log.Error(w, err) + return } + if _, err := w.Write(b); err != nil { - LogError(w, err) + log.Error(w, err) + return } - //LogEnabledResponse(w, v) + + // log.EnabledResponse(w, v) } diff --git a/api/utils_test.go b/api/utils_test.go index 12350c97..f5e1e1cb 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -3,46 +3,11 @@ package api import ( "net/http" "net/http/httptest" - "reflect" "testing" - "github.com/pkg/errors" - "github.com/smallstep/certificates/logging" ) -func TestLogError(t *testing.T) { - theError := errors.New("the error") - type args struct { - rw http.ResponseWriter - err error - } - tests := []struct { - name string - args args - withFields bool - }{ - {"normalLogger", args{httptest.NewRecorder(), theError}, false}, - {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - LogError(tt.args.rw, tt.args.err) - if tt.withFields { - if rl, ok := tt.args.rw.(logging.ResponseLogger); ok { - fields := rl.Fields() - if !reflect.DeepEqual(fields["error"], theError) { - t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError) - } - } else { - t.Error("ResponseWriter does not implement logging.ResponseLogger") - } - } - }) - } -} - func TestJSON(t *testing.T) { type args struct { rw http.ResponseWriter diff --git a/scep/api/api.go b/scep/api/api.go index 77c683ee..a326ea92 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -10,14 +10,14 @@ import ( "strings" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/scep" + microscep "github.com/micromdm/scep/v2/scep" + "github.com/pkg/errors" "go.mozilla.org/pkcs7" - "github.com/pkg/errors" - - microscep "github.com/micromdm/scep/v2/scep" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/scep" ) const ( @@ -337,7 +337,7 @@ func formatCapabilities(caps []string) []byte { func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { if response.Error != nil { - api.LogError(w, response.Error) + log.Error(w, response.Error) } if response.Certificate != nil { From 907bdd686b7c29541055aff3327864f63589cd59 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 23 Mar 2022 23:14:04 +0100 Subject: [PATCH 41/44] Add armv5 build to GoReleaser configuration --- .goreleaser.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.goreleaser.yml b/.goreleaser.yml index 207c75bd..cd4826c9 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -19,6 +19,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 From 904d6712f5d933a110b8cee20073d5b961f86d48 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 24 Mar 2022 00:04:59 +0100 Subject: [PATCH 42/44] Add armv5 build for (cloud|aws)kms --- .goreleaser.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.goreleaser.yml b/.goreleaser.yml index cd4826c9..441d5785 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -40,6 +40,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 @@ -60,6 +61,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 From 6b620c8e9c844f66d4f41cd5a1796c48e38086aa Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 24 Mar 2022 10:54:45 +0100 Subject: [PATCH 43/44] Improve protobuf unmarshaling error handling --- api/utils.go | 52 +++++++++++++++++++++++++++++-- authority/admin/api/admin.go | 4 +-- authority/admin/api/admin_test.go | 16 +++++----- authority/admin/api/middleware.go | 5 +-- authority/admin/api/policy.go | 41 +++++++----------------- authority/policy.go | 20 ++++++------ go.sum | 3 +- policy/engine.go | 26 ++++++++++++++++ policy/engine_test.go | 3 +- 9 files changed, 116 insertions(+), 54 deletions(-) diff --git a/api/utils.go b/api/utils.go index b6ff7960..91091e25 100644 --- a/api/utils.go +++ b/api/utils.go @@ -2,14 +2,16 @@ package api import ( "encoding/json" + "errors" "io" "log" "net/http" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" ) // EnableLogger is an interface that enables response logging for an object. @@ -114,3 +116,49 @@ func ReadProtoJSON(r io.Reader, m proto.Message) error { } return protojson.Unmarshal(data, m) } + +// ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value +// pointed by v. TODO(hs): move this to and integrate with render package. +func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool { + data, err := io.ReadAll(r) + if err != nil { + var wrapper = struct { + Status int `json:"code"` + Message string `json:"message"` + }{ + Status: http.StatusBadRequest, + Message: err.Error(), + } + data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(data) + return false + } + if err := protojson.Unmarshal(data, m); err != nil { + if errors.Is(err, proto.Error) { + var wrapper = struct { + Message string `json:"message"` + }{ + Message: err.Error(), + } + data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write(data) + return false + } + + // fallback to the default error writer + WriteError(w, err) + return false + } + + return true +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 34db5ea2..95b9ba98 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -26,8 +26,8 @@ type adminAuthority interface { UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) - StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error - UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error + CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) RemoveAuthorityPolicy(ctx context.Context) error } diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index bcea31b5..d9592ff2 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -14,11 +14,13 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/types/known/timestamppb" ) type mockAdminAuthority struct { @@ -39,7 +41,7 @@ type mockAdminAuthority struct { MockRemoveProvisioner func(ctx context.Context, id string) error MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) - MockStoreAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) (*linkedca.Policy, error) MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error MockRemoveAuthorityPolicy func(ctx context.Context) error } @@ -139,12 +141,12 @@ func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca. return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) StoreAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { - return errors.New("not implemented yet") +func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) error { - return errors.New("not implemented yet") +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") } func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index c30c7219..74bb2234 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,10 +4,11 @@ import ( "context" "net/http" + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin/db/nosql" - "go.step.sm/linkedca" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -42,7 +43,7 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - ctx := context.WithValue(r.Context(), admin.AdminContextKey, adm) + ctx := linkedca.WithAdmin(r.Context(), adm) next(w, r.WithContext(ctx)) } } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 2f64802f..30e05c48 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -4,10 +4,12 @@ import ( "net/http" "github.com/go-chi/chi" + + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) type policyAdminResponderInterface interface { @@ -82,29 +84,19 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r } var newPolicy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { - api.WriteError(w, err) + if !api.ReadProtoJSONWithCheck(w, r.Body, newPolicy) { return } - adm, err := adminFromContext(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) - return - } + adm := linkedca.AdminFromContext(ctx) - if err := par.auth.StoreAuthorityPolicy(ctx, adm, newPolicy); err != nil { + var createdPolicy *linkedca.Policy + if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return } - storedPolicy, err := par.auth.GetAuthorityPolicy(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) - return - } - - api.JSONStatus(w, storedPolicy, http.StatusCreated) + api.JSONStatus(w, createdPolicy, http.StatusCreated) } // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request @@ -134,24 +126,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r return } - adm, err := adminFromContext(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving admin from context")) - return - } + adm := linkedca.AdminFromContext(ctx) - if err := par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + var updatedPolicy *linkedca.Policy + if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return } - newlyStoredPolicy, err := par.auth.GetAuthorityPolicy(ctx) - if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy after updating")) - return - } - - api.ProtoJSONStatus(w, newlyStoredPolicy, http.StatusOK) + api.ProtoJSONStatus(w, updatedPolicy, http.StatusOK) } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request diff --git a/authority/policy.go b/authority/policy.go index db44e5f4..ee132f31 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -25,42 +25,42 @@ func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, e return policy, nil } -func (a *Authority) StoreAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { +func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() if err := a.checkPolicy(ctx, adm, policy); err != nil { - return err + return nil, err } if err := a.adminDB.CreateAuthorityPolicy(ctx, policy); err != nil { - return err + return nil, err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") + return nil, admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") } - return nil + return policy, nil // TODO: return the newly stored policy } -func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) error { +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() if err := a.checkPolicy(ctx, adm, policy); err != nil { - return err + return nil, err } if err := a.adminDB.UpdateAuthorityPolicy(ctx, policy); err != nil { - return err + return nil, err } if err := a.reloadPolicyEngines(ctx); err != nil { - return admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") + return nil, admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") } - return nil + return policy, nil // TODO: return the updated stored policy } func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { diff --git a/go.sum b/go.sum index ba7cb531..e7681592 100644 --- a/go.sum +++ b/go.sum @@ -639,8 +639,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= diff --git a/policy/engine.go b/policy/engine.go index 63d8452a..c37e1f59 100755 --- a/policy/engine.go +++ b/policy/engine.go @@ -4,7 +4,9 @@ import ( "bytes" "crypto/x509" "crypto/x509/pkix" + "errors" "fmt" + "io" "net" "net/url" "reflect" @@ -40,6 +42,30 @@ type NamePolicyError struct { Detail string } +type NameError struct { + error + Reason NamePolicyReason +} + +func a() { + err := io.EOF + var ne *NameError + errors.As(err, ne) + errors.Is(err, ne) +} + +func newPolicyError(reason NamePolicyReason, err error) error { + return &NameError{ + error: err, + Reason: reason, + } +} + +func newPolicyErrorf(reason NamePolicyReason, format string, args ...interface{}) error { + err := fmt.Errorf(format, args...) + return newPolicyError(reason, err) +} + func (e *NamePolicyError) Error() string { switch e.Reason { case NotAuthorizedForThisName: diff --git a/policy/engine_test.go b/policy/engine_test.go index f7a4b20a..cf406e71 100755 --- a/policy/engine_test.go +++ b/policy/engine_test.go @@ -8,8 +8,9 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/smallstep/assert" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" ) // TODO(hs): the functionality in the policy engine is a nice candidate for trying fuzzing on From 613c99f00f8fb156e732f80d352c3371da025d55 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 24 Mar 2022 13:10:49 +0100 Subject: [PATCH 44/44] Fix linting issues --- api/utils.go | 8 +++--- authority/admin/api/admin_test.go | 4 +-- authority/admin/api/middleware.go | 11 ------- authority/admin/api/middleware_test.go | 7 +---- authority/admin/api/policy.go | 10 +++---- authority/admin/db/nosql/policy.go | 14 ++++----- authority/policy.go | 25 ++++++++-------- authority/provisioner/aws.go | 5 ++-- authority/provisioner/sign_options.go | 9 +++--- authority/tls.go | 40 -------------------------- policy/engine.go | 26 ----------------- 11 files changed, 37 insertions(+), 122 deletions(-) diff --git a/api/utils.go b/api/utils.go index 67b46aa9..761430ed 100644 --- a/api/utils.go +++ b/api/utils.go @@ -83,13 +83,13 @@ func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) Status: http.StatusBadRequest, Message: err.Error(), } - data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + errData, err := json.Marshal(wrapper) if err != nil { panic(err) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - w.Write(data) + w.Write(errData) return false } if err := protojson.Unmarshal(data, m); err != nil { @@ -99,13 +99,13 @@ func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) }{ Message: err.Error(), } - data, err := json.Marshal(wrapper) // TODO(hs): handle err; even though it's very unlikely to fail + errData, err := json.Marshal(wrapper) if err != nil { panic(err) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - w.Write(data) + w.Write(errData) return false } diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index d9592ff2..678cf6a1 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -141,11 +141,11 @@ func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca. return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { +func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { return nil, errors.New("not implemented yet") } -func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { return nil, errors.New("not implemented yet") } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 74bb2234..4ca62bfc 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "go.step.sm/linkedca" @@ -70,13 +69,3 @@ func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTT next(w, r) } } - -// adminFromContext searches the context for a *linkedca.Admin. -// Returns the admin or an error. -func adminFromContext(ctx context.Context) (*linkedca.Admin, error) { - val, ok := ctx.Value(admin.AdminContextKey).(*linkedca.Admin) - if !ok || val == nil { - return nil, admin.NewError(admin.ErrorBadRequestType, "admin not in context") - } - return val, nil -} diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index ffa319db..158374d0 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -169,12 +169,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { } next := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - a := ctx.Value(admin.AdminContextKey) // verifying that the context now has a linkedca.Admin - adm, ok := a.(*linkedca.Admin) - if !ok { - t.Errorf("expected *linkedca.Admin; got %T", a) - return - } + adm := linkedca.AdminFromContext(ctx) // verifying that the context now has a linkedca.Admin opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(adm, adm, opts...) { t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...)) diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 30e05c48..6b59803f 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -8,6 +8,7 @@ import ( "go.step.sm/linkedca" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) @@ -121,7 +122,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r } var newPolicy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { api.WriteError(w, err) return } @@ -220,7 +221,7 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } var newPolicy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, newPolicy); err != nil { + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { api.WriteError(w, err) return } @@ -256,7 +257,7 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } var policy = new(linkedca.Policy) - if err := api.ReadProtoJSON(r.Body, policy); err != nil { + if err := read.ProtoJSON(r.Body, policy); err != nil { api.WriteError(w, err) return } @@ -271,7 +272,7 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, api.ProtoJSONStatus(w, policy, http.StatusOK) } -// DeleteProvisionerPolicy ... +// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -308,7 +309,6 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, api.JSON(w, &DeleteResponse{Status: "ok"}) } -// GetACMEAccountPolicy ... func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { api.JSON(w, "ok") } diff --git a/authority/admin/db/nosql/policy.go b/authority/admin/db/nosql/policy.go index 94ff2a0e..8e11ddb0 100644 --- a/authority/admin/db/nosql/policy.go +++ b/authority/admin/db/nosql/policy.go @@ -63,13 +63,13 @@ func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*db return dbap, nil } -func (db *DB) unmarshalAuthorityPolicy(data []byte, authorityID string) (*linkedca.Policy, error) { - dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID) - if err != nil { - return nil, err - } - return dbap.convert(), nil -} +// func (db *DB) unmarshalAuthorityPolicy(data []byte, authorityID string) (*linkedca.Policy, error) { +// dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID) +// if err != nil { +// return nil, err +// } +// return dbap.convert(), nil +// } func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { diff --git a/authority/policy.go b/authority/policy.go index ee132f31..88f301e0 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -17,23 +17,23 @@ func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, e a.adminMutex.Lock() defer a.adminMutex.Unlock() - policy, err := a.adminDB.GetAuthorityPolicy(ctx) + p, err := a.adminDB.GetAuthorityPolicy(ctx) if err != nil { return nil, err } - return policy, nil + return p, nil } -func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { +func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() - if err := a.checkPolicy(ctx, adm, policy); err != nil { + if err := a.checkPolicy(ctx, adm, p); err != nil { return nil, err } - if err := a.adminDB.CreateAuthorityPolicy(ctx, policy); err != nil { + if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil { return nil, err } @@ -41,18 +41,18 @@ func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm return nil, admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy") } - return policy, nil // TODO: return the newly stored policy + return p, nil // TODO: return the newly stored policy } -func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() - if err := a.checkPolicy(ctx, adm, policy); err != nil { + if err := a.checkPolicy(ctx, adm, p); err != nil { return nil, err } - if err := a.adminDB.UpdateAuthorityPolicy(ctx, policy); err != nil { + if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil { return nil, err } @@ -60,7 +60,7 @@ func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm return nil, admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy") } - return policy, nil // TODO: return the updated stored policy + return p, nil // TODO: return the updated stored policy } func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { @@ -111,11 +111,10 @@ func isAllowed(engine authPolicy.X509Policy, sans []string) error { ) if allowed, err = engine.AreSANsAllowed(sans); err != nil { var policyErr *policy.NamePolicyError - if errors.As(err, &policyErr); policyErr.Reason == policy.NotAuthorizedForThisName { + if isPolicyErr := errors.As(err, &policyErr); isPolicyErr && policyErr.Reason == policy.NotAuthorizedForThisName { return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans) - } else { - return err } + return err } if !allowed { diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 0bbe546b..f8f14671 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -266,7 +266,6 @@ type AWS struct { Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` config *awsConfig - audiences Audiences ctl *Controller x509Policy policy.X509Policy sshHostPolicy policy.HostPolicy @@ -557,7 +556,7 @@ func (p *AWS) readURL(url string) ([]byte, error) { if err != nil { return nil, err } - return nil, fmt.Errorf("Request for metadata returned non-successful status code %d", + return nil, fmt.Errorf("request for metadata returned non-successful status code %d", resp.StatusCode) } @@ -590,7 +589,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) { } defer resp.Body.Close() if resp.StatusCode >= 400 { - return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) + return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode) } token, err := io.ReadAll(resp.Body) if err != nil { diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 2d8a13c3..df2551a3 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -6,7 +6,6 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" - "encoding/asn1" "encoding/json" "net" "net/http" @@ -427,10 +426,10 @@ func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) e return err } -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -) +// var ( +// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} +// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +// ) // type stepProvisionerASN1 struct { // Type int diff --git a/authority/tls.go b/authority/tls.go index 297c796e..df38091c 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -237,14 +237,6 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign // the cert to see if the CA is allowed to sign it. func (a *Authority) isAllowedToSign(cert *x509.Certificate) (bool, error) { - // // check if certificate is an admin identity token certificate and the admin subject exists - // b := isAdminIdentityTokenCertificate(cert) - // _ = b - - // if isAdminIdentityTokenCertificate(cert) && a.admins.HasSubject(cert.Subject.CommonName) { - // return true, nil - // } - // if no policy is configured, the cert is implicitly allowed if a.x509Policy == nil { return true, nil @@ -253,38 +245,6 @@ func (a *Authority) isAllowedToSign(cert *x509.Certificate) (bool, error) { return a.x509Policy.AreCertificateNamesAllowed(cert) } -func isAdminIdentityTokenCertificate(cert *x509.Certificate) bool { - - // TODO: remove this check - - if cert.Issuer.CommonName != "" { - return false - } - - subject := cert.Subject.CommonName - if subject == "" { - return false - } - - dnsNames := cert.DNSNames - if len(dnsNames) != 1 { - return false - } - - if dnsNames[0] != subject { - return false - } - - extras := cert.ExtraExtensions - if len(extras) != 1 { - return false - } - - extra := extras[0] - - return extra.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) -} - // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { diff --git a/policy/engine.go b/policy/engine.go index c37e1f59..63d8452a 100755 --- a/policy/engine.go +++ b/policy/engine.go @@ -4,9 +4,7 @@ import ( "bytes" "crypto/x509" "crypto/x509/pkix" - "errors" "fmt" - "io" "net" "net/url" "reflect" @@ -42,30 +40,6 @@ type NamePolicyError struct { Detail string } -type NameError struct { - error - Reason NamePolicyReason -} - -func a() { - err := io.EOF - var ne *NameError - errors.As(err, ne) - errors.Is(err, ne) -} - -func newPolicyError(reason NamePolicyReason, err error) error { - return &NameError{ - error: err, - Reason: reason, - } -} - -func newPolicyErrorf(reason NamePolicyReason, format string, args ...interface{}) error { - err := fmt.Errorf(format, args...) - return newPolicyError(reason, err) -} - func (e *NamePolicyError) Error() string { switch e.Reason { case NotAuthorizedForThisName: