diff --git a/CHANGELOG.md b/CHANGELOG.md index a09e9a28..fc25c0ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added +- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. ### Changed - Made SCEP CA URL paths dynamic ### Deprecated diff --git a/api/api.go b/api/api.go index 16e24bb2..912e39dd 100644 --- a/api/api.go +++ b/api/api.go @@ -33,6 +33,7 @@ type Authority interface { // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) + AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -43,7 +44,7 @@ type Authority interface { GetProvisioners(cursor string, limit int) (provisioner.List, string, error) Revoke(context.Context, *authority.RevokeOptions) error GetEncryptedKey(kid string) (string, error) - GetRoots() (federation []*x509.Certificate, err error) + GetRoots() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version } diff --git a/api/api_test.go b/api/api_test.go index c7528f9b..717621cd 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "encoding/pem" "fmt" @@ -34,6 +35,7 @@ import ( "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" + "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -171,6 +173,7 @@ type mockAuthority struct { ret1, ret2 interface{} err error authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -208,6 +211,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err return m.ret1.([]provisioner.SignOption), m.err } +func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + if m.authorizeRenewToken != nil { + return m.authorizeRenewToken(ctx, ott) + } + return m.ret1.(*x509.Certificate), m.err +} + func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { if m.getTLSOptions != nil { return m.getTLSOptions() @@ -920,48 +930,141 @@ func Test_caHandler_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } + + // Prepare root and leaf for renew after expiry test. + now := time.Now() + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + root := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + PublicKey: rootPub, + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + NotBefore: now.Add(-2 * time.Hour), + NotAfter: now.Add(time.Hour), + } + root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv) + if err != nil { + t.Fatal(err) + } + expiredLeaf := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Leaf certificate"}, + PublicKey: leafPub, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + EmailAddresses: []string{"test@example.org"}, + } + expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv) + if err != nil { + t.Fatal(err) + } + + // Generate renew after expiry token + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)}) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so) + if err != nil { + t.Fatal(err) + } + generateX5cToken := func(claims jose.Claims) string { + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s + } + tests := []struct { name string tls *tls.ConnectionState + header http.Header cert *x509.Certificate root *x509.Certificate err error statusCode int }{ - {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"ok renew after expiry", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + })}, + }, expiredLeaf, root, nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"fail expired token", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized}, + {"fail invalid root", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized}, } - expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, + authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { + jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) + if err != nil { + return nil, errs.Unauthorized(err.Error()) + } + var claims jose.Claims + if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Time: now, + }, time.Minute); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + return chain[0][0], nil + }, getTLSOptions: func() *authority.TLSOptions { return nil }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls + req.Header = tt.header w := httptest.NewRecorder() h.Renew(logging.NewResponseLogger(w), req) - res := w.Result() - if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) - } + res := w.Result() + defer res.Body.Close() body, err := io.ReadAll(res.Body) - res.Body.Close() if err != nil { t.Errorf("caHandler.Renew unexpected error = %v", err) } + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("%s", body) + } + if tt.statusCode < http.StatusBadRequest { + expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` + + `"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` + + `"certChain":["` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`) + if !bytes.Equal(bytes.TrimSpace(body), expected) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected) } } }) diff --git a/api/renew.go b/api/renew.go index 725322ee..408d91a3 100644 --- a/api/renew.go +++ b/api/renew.go @@ -1,20 +1,28 @@ package api import ( + "crypto/x509" "net/http" + "strings" "github.com/smallstep/certificates/errs" ) +const ( + authorizationHeader = "Authorization" + bearerScheme = "Bearer" +) + // Renew uses the information of certificate in the TLS connection to create a // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + cert, err := h.getPeerCertificate(r) + if err != nil { + WriteError(w, err) return } - certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) + certChain, err := h.Authority.Renew(cert) if err != nil { WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -33,3 +41,15 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { TLSOptions: h.Authority.GetTLSOptions(), }, http.StatusCreated) } + +func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + return r.TLS.PeerCertificates[0], nil + } + if s := r.Header.Get(authorizationHeader); s != "" { + if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { + return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + } + } + return nil, errs.BadRequest("missing client certificate") +} diff --git a/authority/authority.go b/authority/authority.go index f396c588..cc26635e 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -70,10 +70,12 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) - sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) - getIdentityFunc provisioner.GetIdentityFunc + sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) + sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) + getIdentityFunc provisioner.GetIdentityFunc + authorizeRenewFunc provisioner.AuthorizeRenewFunc + authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc adminMutex sync.RWMutex } diff --git a/authority/authorize.go b/authority/authorize.go index 5108f567..7c1c2ff6 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/hex" "net/http" + "net/url" "strconv" "strings" "time" @@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} + isRevoked, err := a.IsRevoked(serial) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) @@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { if isRevoked { return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } - p, ok := a.provisioners.LoadByCertificate(cert) if !ok { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) @@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error } return nil } + +// AuthorizeRenewToken validates the renew token and returns the leaf +// certificate in the x5cInsecure header. +func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + var claims jose.Claims + jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) + if err != nil { + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + leaf := chain[0][0] + if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) + } + + p, ok := a.provisioners.LoadByCertificate(leaf) + if !ok { + return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") + } + if err := a.UseToken(ott, p); err != nil { + return nil, err + } + + if err := claims.ValidateWithLeeway(jose.Expected{ + Issuer: p.GetName(), + Subject: leaf.Subject.CommonName, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + switch err { + case jose.ErrInvalidIssuer: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)")) + case jose.ErrInvalidSubject: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)")) + case jose.ErrNotValidYet: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)")) + case jose.ErrExpired: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)")) + case jose.ErrIssuedInTheFuture: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)")) + default: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + } + + audiences := a.config.GetAudiences().Renew + if !matchesAudience(claims.Audience, audiences) { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) + } + + return leaf, nil +} + +// matchesAudience returns true if A and B share at least one element. +func matchesAudience(as, bs []string) bool { + if len(bs) == 0 || len(as) == 0 { + return false + } + + for _, b := range bs { + for _, a := range as { + if b == a || stripPort(a) == stripPort(b) { + return true + } + } + } + return false +} + +// stripPort attempts to strip the port from the given url. If parsing the url +// produces errors it will just return the passed argument. +func stripPort(rawurl string) string { + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + u.Host = u.Hostname() + return u.String() +} diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6d524a25..b631741a 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -3,11 +3,15 @@ package authority import ( "context" "crypto" + "crypto/ed25519" "crypto/rand" "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" "encoding/base64" "fmt" "net/http" + "reflect" "strconv" "testing" "time" @@ -20,6 +24,7 @@ import ( "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -753,6 +758,7 @@ func TestAuthority_Authorize(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) { fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") + fooCrt.NotAfter = time.Now().Add(time.Hour) assert.FatalError(t, err) renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") @@ -822,7 +828,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { return &authorizeTest{ auth: a, cert: renewDisabledCrt, - err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"), code: http.StatusUnauthorized, } }, @@ -909,6 +915,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner. } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -917,6 +924,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -1003,6 +1016,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } func TestAuthority_authorizeSSHRenew(t *testing.T) { + now := time.Now().UTC() + sshpop := func(a *Authority) (*ssh.Certificate, string) { + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return cert, token + } + a := testAuthority(t) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) @@ -1012,8 +1042,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) - now := time.Now().UTC() - validIssuer := "step-cli" type authorizeTest struct { @@ -1050,27 +1078,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { code: http.StatusUnauthorized, } }, + "fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return errs.Forbidden("forbidden") + })) + _, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, + err: errors.New("authority.authorizeSSHRenew: forbidden"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *authorizeTest { - key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") - assert.FatalError(t, err) - signer, ok := key.(crypto.Signer) - assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") - sshSigner, err := ssh.NewSignerFromSigner(signer) - assert.FatalError(t, err) - - cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) - assert.FatalError(t, err) - - p, ok := a.provisioners.Load("sshpop/sshpop") - assert.Fatal(t, ok, "sshpop provisioner not found in test authority") - - tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", - []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) - assert.FatalError(t, err) - + cert, token := sshpop(a) return &authorizeTest{ auth: a, - token: tok, + token: token, + cert: cert, + } + }, + "ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return nil + })) + cert, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, cert: cert, } }, @@ -1290,3 +1325,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { }) } } + +func TestAuthority_AuthorizeRenewToken(t *testing.T) { + ctx := context.Background() + type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` + } + + _, signer, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer) + if err != nil { + t.Fatal(err) + } + _, otherSigner, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { + chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + if err != nil { + t.Fatal(err) + } + + var x5c []string + for _, c := range chain { + x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw)) + } + + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", x5c) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so) + if err != nil { + t.Fatal(err) + } + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s, chain[0] + } + + now := time.Now() + a1 := testAuthority(t) + t1, c1 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + t2, c2 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now.Add(-time.Hour) + cert.NotAfter = now.Add(-time.Minute) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "bad-issuer", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "bad-subject", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + + type args struct { + ctx context.Context + ott string + } + tests := []struct { + name string + authority *Authority + args args + want *x509.Certificate + wantErr bool + }{ + {"ok", a1, args{ctx, t1}, c1, false}, + {"ok expired cert", a1, args{ctx, t2}, c2, false}, + {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, + {"fail token reuse", a1, args{ctx, t1}, nil, true}, + {"fail token signature", a1, args{ctx, badSigner}, nil, true}, + {"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true}, + {"fail token iss", a1, args{ctx, badIssuer}, nil, true}, + {"fail token sub", a1, args{ctx, badSubject}, nil, true}, + {"fail token iat", a1, args{ctx, badNotBefore}, nil, true}, + {"fail token iat", a1, args{ctx, badExpiry}, nil, true}, + {"fail token iat", a1, args{ctx, badIssuedAt}, nil, true}, + {"fail token aud", a1, args{ctx, badAudience}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/config/config.go b/authority/config/config.go index 589b5bbf..2c437725 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -26,23 +26,27 @@ var ( DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false + // DefaultAllowRenewAfterExpiry allows renewals even if the certificate is + // expired. + DefaultAllowRenewAfterExpiry = false // DefaultEnableSSHCA enable SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // GlobalProvisionerClaims default claims for the Authority. Can be overridden // by provisioner specific claims. GlobalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &DefaultDisableRenewal, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &DefaultEnableSSHCA, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + DisableRenewal: &DefaultDisableRenewal, + AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry, } ) @@ -269,28 +273,32 @@ func (c *Config) GetAudiences() provisioner.Audiences { } for _, name := range c.DNSNames { + hostname := toHostname(name) audiences.Sign = append(audiences.Sign, - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname), + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname)) + audiences.Renew = append(audiences.Renew, + fmt.Sprintf("https://%s/1.0/renew", hostname), + fmt.Sprintf("https://%s/renew", hostname)) audiences.Revoke = append(audiences.Revoke, - fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), - fmt.Sprintf("https://%s/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/revoke", hostname), + fmt.Sprintf("https://%s/revoke", hostname)) audiences.SSHSign = append(audiences.SSHSign, - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname), + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname)) audiences.SSHRevoke = append(audiences.SSHRevoke, - fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), - fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname), + fmt.Sprintf("https://%s/ssh/revoke", hostname)) audiences.SSHRenew = append(audiences.SSHRenew, - fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), - fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/renew", hostname), + fmt.Sprintf("https://%s/ssh/renew", hostname)) audiences.SSHRekey = append(audiences.SSHRekey, - fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), - fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname), + fmt.Sprintf("https://%s/ssh/rekey", hostname)) } return audiences diff --git a/authority/options.go b/authority/options.go index f92db99b..a1238b1d 100644 --- a/authority/options.go +++ b/authority/options.go @@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e } } +// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of +// an X.509 certificate. +func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeRenewFunc = fn + return nil + } +} + +// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal +// of a SSH certificate. +func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeSSHRenewFunc = fn + return nil + } +} + // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 21958d36..913d0ace 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -6,7 +6,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -24,7 +23,7 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (p *ACME) DefaultTLSCertDuration() time.Duration { - return p.claimer.DefaultTLSCertDuration() + return p.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a JWK type. @@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign does not do any validation, because all validation is handled @@ -97,10 +92,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -118,8 +113,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bd173f87..bc4e97e0 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) { } func TestACME_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *ACME cert *x509.Certificate @@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateACME() assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -172,18 +179,18 @@ func TestACME_AuthorizeSign(t *testing.T) { for _, o := range opts { switch v := o.(type) { case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeACME)) + assert.Equals(t, v.Type, TypeACME) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case *forceCNOption: assert.Equals(t, v.ForceCN, tc.p.ForceCN) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index fdad7b4a..5f79d7d0 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -264,9 +264,8 @@ type AWS struct { IIDRoots string `json:"iidRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *awsConfig - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Add default config if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // validate IMDS versions if len(p.IMDSVersions) == 0 { @@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) { } } - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -473,11 +470,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -486,10 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized @@ -664,7 +658,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } // validate audiences with the defaults - if !matchesAudience(payload.Audience, p.audiences.Sign) { + if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") } @@ -704,7 +698,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -752,11 +746,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 0d2786db..559a48f1 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -677,18 +677,18 @@ func TestAWS_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAWS)) + assert.Equals(t, v.Type, TypeAWS) assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), tt.args.cn) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: @@ -726,7 +726,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") @@ -747,7 +747,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -824,6 +824,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } func TestAWS_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() @@ -832,7 +833,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -845,8 +846,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 384617e0..d9654566 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -96,10 +96,10 @@ type Azure struct { DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -203,27 +203,24 @@ func (p *Azure) Init(config Config) (err error) { case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint - if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { - return err + if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { + return } if err := p.oidcConfig.Validate(); err != nil { return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) } // Get JWK key set if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { - return err + return } - return nil + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken returns the claims, name, group, subscription, identityObjectID, error. @@ -355,10 +352,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -367,15 +364,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } @@ -420,11 +414,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 4ab734d5..c05685b7 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -506,18 +506,18 @@ func TestAzure_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAzure)) + assert.Equals(t, v.Type, TypeAzure) assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "virtualMachine") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -536,6 +536,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { } func TestAzure_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() @@ -544,7 +545,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -557,8 +558,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -595,7 +602,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("subject", "caURL") @@ -616,7 +623,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"virtualMachine"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 629a313c..2a3e2c61 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -10,10 +10,10 @@ import ( // Claims so that individual provisioners can override global claims. type Claims struct { // TLS CA properties - MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` - MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` - DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` - DisableRenewal *bool `json:"disableRenewal,omitempty"` + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` + // SSH CA properties MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` @@ -22,6 +22,10 @@ type Claims struct { MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"` + + // Renewal properties + DisableRenewal *bool `json:"disableRenewal,omitempty"` + AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() + allowRenewAfterExpiry := c.AllowRenewAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() + return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - DisableRenewal: &disableRenewal, - MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, - MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, - DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, - MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, - MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, - DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, - EnableSSHCA: &enableSSHCA, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, + DisableRenewal: &disableRenewal, + AllowRenewAfterExpiry: &allowRenewAfterExpiry, } } @@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } +// AllowRenewAfterExpiry returns if the renewal flow is authorized if the +// certificate is expired. If the property is not set within the provisioner +// then the global value from the authority configuration will be used. +func (c *Claimer) AllowRenewAfterExpiry() bool { + if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil { + return *c.global.AllowRenewAfterExpiry + } + return *c.claims.AllowRenewAfterExpiry +} + // DefaultSSHCertDuration returns the default SSH certificate duration for the // given certificate type. func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 1bec8689..8bbace5f 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // proper id to load the provisioner. func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { for _, e := range cert.Extensions { - if e.Id.Equal(stepOIDProvisioner) { - var provisioner stepProvisionerASN1 + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 348b797c..24db4593 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) { } func TestCollection_LoadByCertificate(t *testing.T) { + mustExtension := func(typ Type, name, credentialID string) pkix.Extension { + e := Extension{ + Type: typ, Name: name, CredentialID: credentialID, + } + ext, err := e.ToExtension() + if err != nil { + t.Fatal(err) + } + return ext + } + p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateOIDC() @@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) { byName.Store(p2.GetName(), p2) byName.Store(p3.GetName(), p3) - ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) - assert.FatalError(t, err) - ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) - assert.FatalError(t, err) - ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "") - assert.FatalError(t, err) - notFoundExt, err := createProvisionerExtension(1, "foo", "bar") - assert.FatalError(t, err) - ok1Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok1Ext}, + Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)}, } ok2Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok2Ext}, + Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)}, } ok3Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok3Ext}, + Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")}, } notFoundCert := &x509.Certificate{ - Extensions: []pkix.Extension{notFoundExt}, + Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")}, } badCert := &x509.Certificate{ Extensions: []pkix.Extension{ - {Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, + {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")}, }, } diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go new file mode 100644 index 00000000..a91ebaac --- /dev/null +++ b/authority/provisioner/controller.go @@ -0,0 +1,194 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "regexp" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" +) + +// Controller wraps a provisioner with other attributes useful in callback +// functions. +type Controller struct { + Interface + Audiences *Audiences + Claimer *Claimer + IdentityFunc GetIdentityFunc + AuthorizeRenewFunc AuthorizeRenewFunc + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc +} + +// NewController initializes a new provisioner controller. +func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { + claimer, err := NewClaimer(claims, config.Claims) + if err != nil { + return nil, err + } + return &Controller{ + Interface: p, + Audiences: &config.Audiences, + Claimer: claimer, + IdentityFunc: config.GetIdentityFunc, + AuthorizeRenewFunc: config.AuthorizeRenewFunc, + AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, + }, nil +} + +// GetIdentity returns the identity for a given email. +func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) { + if c.IdentityFunc != nil { + return c.IdentityFunc(ctx, c.Interface, email) + } + return DefaultIdentityFunc(ctx, c.Interface, email) +} + +// AuthorizeRenew returns nil if the given cert can be renewed, returns an error +// otherwise. +func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + if c.AuthorizeRenewFunc != nil { + return c.AuthorizeRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeRenew(ctx, c, cert) +} + +// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an +// error otherwise. +func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error { + if c.AuthorizeSSHRenewFunc != nil { + return c.AuthorizeSSHRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeSSHRenew(ctx, c, cert) +} + +// Identity is the type representing an externally supplied identity that is used +// by provisioners to populate certificate fields. +type Identity struct { + Usernames []string `json:"usernames"` + Permissions `json:"permissions"` +} + +// GetIdentityFunc is a function that returns an identity. +type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) + +// AuthorizeRenewFunc is a function that returns nil if the renewal of a +// certificate is enabled. +type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error + +// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the +// given SSH certificate is enabled. +type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error + +// DefaultIdentityFunc return a default identity depending on the provisioner +// type. For OIDC email is always present and the usernames might +// contain empty strings. +func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { + switch k := p.(type) { + case *OIDC: + // OIDC principals would be: + // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled + // 2. Sanitized local. + // 3. Raw local (if different). + // 4. Email address. + name := SanitizeSSHUserPrincipal(email) + if !sshUserRegex.MatchString(name) { + return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) + } + usernames := []string{name} + if i := strings.LastIndex(email, "@"); i >= 0 { + usernames = append(usernames, email[:i]) + } + usernames = append(usernames, email) + return &Identity{ + Usernames: SanitizeStringSlices(usernames), + }, nil + default: + return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) + } +} + +// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + now := time.Now().Truncate(time.Second) + if now.Before(cert.NotBefore) { + return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) + } + if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + unixNow := time.Now().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return errs.Unauthorized("certificate is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") + +// SanitizeStringSlices removes duplicated an empty strings. +func SanitizeStringSlices(original []string) []string { + output := []string{} + seen := make(map[string]struct{}) + for _, entry := range original { + if entry == "" { + continue + } + if _, value := seen[entry]; !value { + seen[entry] = struct{}{} + output = append(output, entry) + } + } + return output +} + +// SanitizeSSHUserPrincipal grabs an email or a string with the format +// local@domain and returns a sanitized version of the local, valid to be used +// as a user name. If the email starts with a letter between a and z, the +// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. +func SanitizeSSHUserPrincipal(email string) string { + if i := strings.LastIndex(email, "@"); i >= 0 { + email = email[:i] + } + return strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= '0' && r <= '9': + return r + case r == '-': + return '-' + case r == '.': // drop dots + return -1 + default: + return '_' + } + }, strings.ToLower(email)) +} diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go new file mode 100644 index 00000000..9fb90e9d --- /dev/null +++ b/authority/provisioner/controller_test.go @@ -0,0 +1,391 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "fmt" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +var trueValue = true + +func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer { + t.Helper() + c, err := NewClaimer(claims, global) + if err != nil { + t.Fatal(err) + } + return c +} +func mustDuration(t *testing.T, s string) *Duration { + t.Helper() + d, err := NewDuration(s) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestNewController(t *testing.T) { + type args struct { + p Interface + claims *Claims + config Config + } + tests := []struct { + name string + args args + want *Controller + wantErr bool + }{ + {"ok", args{&JWK{}, nil, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, false}, + {"ok with claims", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + }, false}, + {"fail claimer", args{&JWK{}, &Claims{ + MinTLSDur: mustDuration(t, "24h"), + MaxTLSDur: mustDuration(t, "2h"), + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewController() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_GetIdentity(t *testing.T) { + ctx := context.Background() + type fields struct { + Interface Interface + IdentityFunc GetIdentityFunc + } + type args struct { + ctx context.Context + email string + } + tests := []struct { + name string + fields fields + args args + want *Identity + wantErr bool + }{ + {"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane", "jane@doe.org"}, + }, false}, + {"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return &Identity{Usernames: []string{"jane"}}, nil + }}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane"}, + }, false}, + {"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true}, + {"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return nil, fmt.Errorf("an error") + }}, args{ctx, "jane@doe.org"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + IdentityFunc: tt.fields.IdentityFunc, + } + got, err := c.GetIdentity(tt.args.ctx, tt.args.email) + if (err != nil) != tt.wantErr { + t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_AuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now().Truncate(time.Second) + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeRenewFunc AuthorizeRenewFunc + } + type args struct { + ctx context.Context + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc, + } + if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestController_AuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + } + type args struct { + ctx context.Context + cert *ssh.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc, + } + if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now().Truncate(time.Second) + type args struct { + ctx context.Context + p *Controller + cert *x509.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *ssh.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/extension.go b/authority/provisioner/extension.go new file mode 100644 index 00000000..c316329d --- /dev/null +++ b/authority/provisioner/extension.go @@ -0,0 +1,73 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" +) + +var ( + // StepOIDRoot is the root OID for smallstep. + StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + + // StepOIDProvisioner is the OID for the provisioner extension. + StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...) +) + +// Extension is the Go representation of the provisioner extension. +type Extension struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string +} + +type extensionASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` +} + +// Marshal marshals the extension using encoding/asn1. +func (e *Extension) Marshal() ([]byte, error) { + return asn1.Marshal(extensionASN1{ + Type: int(e.Type), + Name: []byte(e.Name), + CredentialID: []byte(e.CredentialID), + KeyValuePairs: e.KeyValuePairs, + }) +} + +// ToExtension returns the pkix.Extension representation of the provisioner +// extension. +func (e *Extension) ToExtension() (pkix.Extension, error) { + b, err := e.Marshal() + if err != nil { + return pkix.Extension{}, err + } + return pkix.Extension{ + Id: StepOIDProvisioner, + Value: b, + }, nil +} + +// GetProvisionerExtension goes through all the certificate extensions and +// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1). +func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) { + for _, e := range cert.Extensions { + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 + if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { + return nil, false + } + return &Extension{ + Type: Type(provisioner.Type), + Name: string(provisioner.Name), + CredentialID: string(provisioner.CredentialID), + KeyValuePairs: provisioner.KeyValuePairs, + }, true + } + } + return nil, false +} diff --git a/authority/provisioner/extension_test.go b/authority/provisioner/extension_test.go new file mode 100644 index 00000000..69be9e18 --- /dev/null +++ b/authority/provisioner/extension_test.go @@ -0,0 +1,158 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "testing" + + "go.step.sm/crypto/pemutil" +) + +func TestExtension_Marshal(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.Marshal() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want) + } + }) + } +} + +func TestExtension_ToExtension(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want pkix.Extension + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.ToExtension() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetProvisionerExtension(t *testing.T) { + mustCertificate := func(fn string) *x509.Certificate { + cert, err := pemutil.ReadCertificate(fn) + if err != nil { + t.Fatal(err) + } + return cert + } + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *Extension + want1 bool + }{ + {"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{ + Type: TypeJWK, + Name: "mariano@smallstep.com", + CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ", + }, true}, + {"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false}, + {"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := GetProvisionerExtension(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e46f4ce4..6070b640 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -88,10 +88,9 @@ type GCP struct { InstanceAge Duration `json:"instanceAge,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *gcpConfig keyStore *keyStore - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name should uniquely @@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { } // Init validates and initializes the GCP provisioner. -func (p *GCP) Init(config Config) error { - var err error +func (p *GCP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Initialize key store - p.keyStore, err = newKeyStore(p.config.CertsURL) - if err != nil { - return err + if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil { + return } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -269,19 +265,16 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized. @@ -328,7 +321,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } // validate audiences with the defaults - if !matchesAudience(claims.Audience, p.audiences.Sign) { + if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } @@ -383,7 +376,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -431,11 +424,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 5f6f9bc7..b8c437c3 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -549,18 +549,18 @@ func TestGCP_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeGCP)) + assert.Equals(t, v.Type, TypeGCP) assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -595,7 +595,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], @@ -622,7 +622,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -698,6 +698,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } func TestGCP_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() @@ -706,7 +707,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -719,8 +720,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renewal-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 137915c8..c014bec0 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -35,8 +35,7 @@ type JWK struct { EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -98,13 +97,8 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -146,13 +140,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } @@ -179,12 +173,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -193,18 +187,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } @@ -261,11 +252,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil @@ -273,6 +264,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.SSHRevoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index deae8f7a..dde2f836 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) { }, "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, } }, } @@ -300,18 +300,18 @@ func TestJWK_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeJWK)) + assert.Equals(t, v.Type, TypeJWK) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "subject") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: @@ -325,6 +325,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } func TestJWK_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() @@ -333,7 +334,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -346,8 +347,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -373,7 +380,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p2.Claims = &Claims{EnableSSHCA: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) @@ -402,8 +409,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -485,8 +492,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { signer, err := generateJSONWebKey() assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index d260f5ec..557d571a 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -42,16 +42,15 @@ type k8sSAPayload struct { // entity trusted to make signature requests. type K8sSA struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - PubKeys []byte `json:"publicKeys,omitempty"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + PubKeys []byte `json:"publicKeys,omitempty"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } @@ -240,27 +234,24 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") } @@ -282,11 +273,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Require type, key-id and principals in the SignSSHOptions. &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 176cdfd3..378d4471 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { } func TestK8sSA_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *K8sSA cert *x509.Certificate @@ -192,21 +193,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateK8sSA(nil) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -276,16 +283,16 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeK8sSA)) + assert.Equals(t, v.Type, TypeK8sSA) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p.Claims = &Claims{EnableSSHCA: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertOptionsRequireValidator: assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshDefaultDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 72a275ff..1a6eee3e 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -34,19 +34,18 @@ const ( // https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by // go.step.sm/crypto/x25519. type Nebula struct { - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - caPool *nebula.NebulaCAPool - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + caPool *nebula.NebulaCAPool + ctl *Controller } // Init verifies and initializes the Nebula provisioner. -func (p *Nebula) Init(config Config) error { +func (p *Nebula) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error { return errors.New("provisioner root(s) cannot be empty") } - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) if err != nil { return errs.InternalServer("failed to create ca pool: %v", err) } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // GetID returns the provisioner id. @@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) { // AuthorizeSign returns the list of SignOption for a Sign request. func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - crt, claims, err := p.authorizeToken(token, p.audiences.Sign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, err } @@ -154,7 +148,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // modifiers / withOptions newProvisionerExtensionOption(TypeNebula, p.Name, ""), profileLimitDuration{ - def: p.claimer.DefaultTLSCertDuration(), + def: p.ctl.Claimer.DefaultTLSCertDuration(), notBefore: crt.Details.NotBefore, notAfter: crt.Details.NotAfter, }, @@ -165,18 +159,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, IPs: crt.Details.Ips, }, defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // Currently the Nebula provisioner only grants host SSH certificates. func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, err } @@ -254,11 +248,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti return append(signOptions, templateOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, crt.Details.NotAfter}, + &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil @@ -266,23 +260,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti // AuthorizeRenew returns an error if the renewal is disabled. func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, crt) } // AuthorizeRevoke returns an error if the token is not valid. func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { - return p.validateToken(token, p.audiences.Revoke) + return p.validateToken(token, p.ctl.Audiences.Revoke) } // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { + if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil { return err } return nil diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index bc539af1..b190d607 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) - t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) + t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) _, claims, err := parseToken(t1) if err != nil { t.Fatal(err) @@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) { ctx := context.TODO() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool @@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1"}, }, crt, priv) - okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) - okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv) + okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), }, crt, priv) - failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "user", }, crt, priv) - failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, @@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) { ctx := context.TODO() + now := time.Now().Truncate(time.Second) + // Ok provisioner p, _, _ := mustNebulaProvisioner(t) @@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) { args args wantErr bool }{ - {"ok", p, args{ctx, &x509.Certificate{}}, false}, - {"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, + {"ok", p, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"fail disabled", pDisabled, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Provisioner with SSH disabled var bFalse bool @@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) { t1 := now() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) - okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv) + okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv) - okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) + okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv) // Token with errors - failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) - failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) + failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) - failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) // Not a nebula token jwk, err := generateJSONWebKey() @@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.Sign[0]}, + Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", @@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.SSHSign[0]}, + Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { @@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) { want1 *jwtPayload wantErr bool }{ - {"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, - {"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, }, false}, - {"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ @@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) { }, }, }, false}, - {"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, }, false}, - {"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, - {"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, - {"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, - {"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, - {"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, - {"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, - {"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, + {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index ac1f2a25..1fc9bb4b 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -92,8 +92,7 @@ type OIDC struct { Options *Options `json:"options,omitempty"` configuration openIDConfiguration keyStore *keyStore - claimer *Claimer - getIdentityFunc GetIdentityFunc + ctl *Controller } func sanitizeEmail(email string) string { @@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) { } } - // Update claims with global ones - if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint u, err := url.Parse(o.ConfigurationEndpoint) if err != nil { @@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) { return err } - // Set the identity getter if it exists, otherwise use the default. - if config.GetIdentityFunc == nil { - o.getIdentityFunc = DefaultIdentityFunc - } else { - o.getIdentityFunc = config.GetIdentityFunc - } - return nil + o.ctl, err = NewController(o, o.Claims, config) + return } // ValidatePayload validates the given token payload. @@ -359,10 +348,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), - profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -371,15 +360,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if o.claimer.IsDisableRenewal() { - return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName()) - } - return nil + return o.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !o.claimer.IsSSHCAEnabled() { + if !o.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) @@ -394,7 +380,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. Note that the PreferredUsername might be empty. // TBD: Would preferred_username present a safety issue here? - iden, err := o.getIdentityFunc(ctx, o, claims.Email) + iden, err := o.ctl.GetIdentity(ctx, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } @@ -445,11 +431,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{o.claimer}, + &sshDefaultDuration{o.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{o.claimer}, + &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 7bf6ad7a..c1a94b1d 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -327,16 +327,16 @@ func TestOIDC_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeOIDC)) + assert.Equals(t, v.Type, TypeOIDC) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: @@ -411,6 +411,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { } func TestOIDC_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() @@ -419,7 +420,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -432,8 +433,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -478,7 +485,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p6.Claims = &Claims{EnableSSHCA: &disable} - p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) + p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) assert.FatalError(t, err) // Update configuration endpoints and initialize @@ -494,10 +501,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } // Additional test needed for empty usernames and duplicate email and usernames @@ -527,8 +534,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 55ebe092..7438ea17 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -6,7 +6,6 @@ import ( "encoding/json" stderrors "errors" "net/url" - "regexp" "strings" "github.com/pkg/errors" @@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse") // Audiences stores all supported audiences by request type. type Audiences struct { Sign []string + Renew []string Revoke []string SSHSign []string SSHRevoke []string @@ -57,6 +57,7 @@ type Audiences struct { // All returns all supported audiences across all request types in one list. func (a Audiences) All() (auds []string) { auds = a.Sign + auds = append(auds, a.Renew...) auds = append(auds, a.Revoke...) auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHRevoke...) @@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) { func (a Audiences) WithFragment(fragment string) Audiences { ret := Audiences{ Sign: make([]string, len(a.Sign)), + Renew: make([]string, len(a.Renew)), Revoke: make([]string, len(a.Revoke)), SSHSign: make([]string, len(a.SSHSign)), SSHRevoke: make([]string, len(a.SSHRevoke)), @@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences { ret.Sign[i] = s } } + for i, s := range a.Renew { + if u, err := url.Parse(s); err == nil { + ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() + } else { + ret.Renew[i] = s + } + } for i, s := range a.Revoke { if u, err := url.Parse(s); err == nil { ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() @@ -210,6 +219,12 @@ type Config struct { // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc + // AuthorizeRenewFunc is a function that returns nil if a given X.509 + // certificate can be renewed. + AuthorizeRenewFunc AuthorizeRenewFunc + // AuthorizeSSHRenewFunc is a function that returns nil if a given SSH + // certificate can be renewed. + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc } type provisioner struct { @@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error { return nil } -var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") - -// SanitizeSSHUserPrincipal grabs an email or a string with the format -// local@domain and returns a sanitized version of the local, valid to be used -// as a user name. If the email starts with a letter between a and z, the -// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. -func SanitizeSSHUserPrincipal(email string) string { - if i := strings.LastIndex(email, "@"); i >= 0 { - email = email[:i] - } - return strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z': - return r - case r >= '0' && r <= '9': - return r - case r == '-': - return '-' - case r == '.': // drop dots - return -1 - default: - return '_' - } - }, strings.ToLower(email)) -} - type base struct{} // AuthorizeSign returns an unimplemented error. Provisioners should overwrite @@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") } -// Identity is the type representing an externally supplied identity that is used -// by provisioners to populate certificate fields. -type Identity struct { - Usernames []string `json:"usernames"` - Permissions `json:"permissions"` -} - // Permissions defines extra extensions and critical options to grant to an SSH certificate. type Permissions struct { Extensions map[string]string `json:"extensions"` CriticalOptions map[string]string `json:"criticalOptions"` } -// GetIdentityFunc is a function that returns an identity. -type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) - -// DefaultIdentityFunc return a default identity depending on the provisioner -// type. For OIDC email is always present and the usernames might -// contain empty strings. -func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { - switch k := p.(type) { - case *OIDC: - // OIDC principals would be: - // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled - // 2. Sanitized local. - // 3. Raw local (if different). - // 4. Email address. - name := SanitizeSSHUserPrincipal(email) - if !sshUserRegex.MatchString(name) { - return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) - } - usernames := []string{name} - if i := strings.LastIndex(email, "@"); i >= 0 { - usernames = append(usernames, email[:i]) - } - usernames = append(usernames, email) - return &Identity{ - Usernames: SanitizeStringSlices(usernames), - }, nil - default: - return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) - } -} - -// SanitizeStringSlices removes duplicated an empty strings. -func SanitizeStringSlices(original []string) []string { - output := []string{} - seen := make(map[string]struct{}) - for _, entry := range original { - if entry == "" { - continue - } - if _, value := seen[entry]; !value { - seen[entry] = struct{}{} - output = append(output, entry) - } - } - return output -} - // MockProvisioner for testing type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 5d67762c..f4cffd78 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -11,28 +11,30 @@ import ( // SCEP provisioning flow type SCEP struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` ChallengePassword string `json:"challenge,omitempty"` Capabilities []string `json:"capabilities,omitempty"` + // IncludeRoot makes the provisioner return the CA root in addition to the // intermediate in the GetCACerts response IncludeRoot bool `json:"includeRoot,omitempty"` + // MinimumPublicKeyLength is the minimum length for public keys in CSRs MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` secretChallengePassword string encryptionAlgorithm int + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (s *SCEP) DefaultTLSCertDuration() time.Duration { - return s.claimer.DefaultTLSCertDuration() + return s.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a SCEP type. @@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil { - return err - } - // Mask the actual challenge value, so it won't be marshaled s.secretChallengePassword = s.ChallengePassword s.ChallengePassword = "*** redacted ***" @@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - return err + s.ctl, err = NewController(s, s.Claims, config) + return } // AuthorizeSign does not do any verification, because all verification is handled @@ -126,10 +124,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newForceCNOption(s.ForceCN), - profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), - newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), }, nil } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 34b2e99b..80dfc66e 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -6,7 +6,6 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" - "encoding/asn1" "encoding/json" "net" "net/http" @@ -14,7 +13,6 @@ import ( "reflect" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" @@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -) - -type stepProvisionerASN1 struct { - Type int - Name []byte - CredentialID []byte - KeyValuePairs []string `asn1:"optional,omitempty"` -} +// type stepProvisionerASN1 struct { +// Type int +// Name []byte +// CredentialID []byte +// KeyValuePairs []string `asn1:"optional,omitempty"` +// } type forceCNOption struct { ForceCN bool @@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { } type provisionerExtensionOption struct { - Type int - Name string - CredentialID string - KeyValuePairs []string + Extension } func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { return &provisionerExtensionOption{ - Type: int(typ), - Name: name, - CredentialID: credentialID, - KeyValuePairs: keyValuePairs, + Extension: Extension{ + Type: typ, + Name: name, + CredentialID: credentialID, + KeyValuePairs: keyValuePairs, + }, } } func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { - ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) + ext, err := o.ToExtension() if err != nil { return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } @@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) return nil } - -func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) { - b, err := asn1.Marshal(stepProvisionerASN1{ - Type: typ, - Name: []byte(name), - CredentialID: []byte(credentialID), - KeyValuePairs: keyValuePairs, - }) - if err != nil { - return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension") - } - return pkix.Extension{ - Id: stepOIDProvisioner, - Critical: false, - Value: b, - }, nil -} diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 32b8e3c6..fc4d675a 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) { valid: func(cert *x509.Certificate) { if assert.Len(t, 1, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) } }, } }, "ok/prepend": func() test { return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, + cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, valid: func(cert *x509.Certificate) { if assert.Len(t, 3, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) assert.False(t, ext.Critical) } }, diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index b59d6945..28a35639 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) - v := sshCertValidityValidator{p.claimer} + v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string @@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) { tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), @@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/type-not-recognized": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), @@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validAfter-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), @@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validBefore-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), @@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3039d2a3..9de0fca2 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -29,8 +29,7 @@ type SSHPOP struct { Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller sshPubKeys *SSHKeys } @@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a SSHPOP type. -func (p *SSHPOP) Init(config Config) error { +func (p *SSHPOP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error { return errors.New("provisioner public SSH validation keys cannot be empty") } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys - return nil + + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -109,7 +104,7 @@ func (p *SSHPOP) Init(config Config) error { // e.g. a Sign request will auth/validate different fields than a Revoke request. // // Checking for certificate revocation has been moved to the authority package. -func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { +func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) { sshCert, jwt, err := ExtractSSHPOPCert(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, @@ -117,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa } // Check validity period of the certificate. - n := time.Now() - if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") - } - if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + // + // Controller.AuthorizeSSHRenew will validate this on the renewal flow. + if checkValidity { + unixNow := time.Now().Unix() + if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") + } + if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + } } + sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) if !ok { return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") @@ -186,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // AuthorizeSSHRevoke validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { - claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } @@ -199,22 +199,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { // AuthorizeSSHRenew validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRenew) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } - - return claims.sshCert, nil - + return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert) } // AuthorizeSSHRekey validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRekey) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } @@ -225,11 +223,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil - } // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index da036864..b548fe71 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) { } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -207,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { + if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) @@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/testdata/certs/bad-extension.crt b/authority/provisioner/testdata/certs/bad-extension.crt new file mode 100644 index 00000000..ecce0f28 --- /dev/null +++ b/authority/provisioner/testdata/certs/bad-extension.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy +MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc +pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49 +BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY +wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/testdata/certs/good-extension.crt b/authority/provisioner/testdata/certs/good-extension.crt new file mode 100644 index 00000000..103353a7 --- /dev/null +++ b/authority/provisioner/testdata/certs/good-extension.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx +MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0 +4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz +bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf +SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB +bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..669693d6 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,20 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, + defaultDisableRenewal = false + defaultAllowRenewAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, @@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &JWK{ + + p := &JWK{ Name: name, Type: "JWK", Key: &public, EncryptedKey: encrypted, Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { @@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } pubKeys := []interface{}{fooPub, barPub} if inputPubKey != nil { pubKeys = append(pubKeys, inputPubKey) } - return &K8sSA{ - Name: K8sSAName, - Type: "K8sSA", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - pubKeys: pubKeys, - }, nil + p := &K8sSA{ + Name: K8sSAName, + Type: "K8sSA", + Claims: &globalProvisionerClaims, + pubKeys: pubKeys, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateSSHPOP() (*SSHPOP, error) { @@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") if err != nil { return nil, err @@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) { return nil, err } - return &SSHPOP{ - Name: name, - Type: "SSHPOP", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, + p := &SSHPOP{ + Name: name, + Type: "SSHPOP", + Claims: &globalProvisionerClaims, sshPubKeys: &SSHKeys{ UserKeys: []ssh.PublicKey{userKey}, HostKeys: []ssh.PublicKey{hostKey}, }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateX5C(root []byte) (*X5C, error) { @@ -283,11 +279,6 @@ M46l92gdOozT if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - rootPool := x509.NewCertPool() var ( @@ -305,15 +296,17 @@ M46l92gdOozT } rootPool.AddCert(cert) } - return &X5C{ - Name: name, - Type: "X5C", - Roots: root, - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - rootPool: rootPool, - }, nil + p := &X5C{ + Name: name, + Type: "X5C", + Roots: root, + Claims: &globalProvisionerClaims, + rootPool: rootPool, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateOIDC() (*OIDC, error) { @@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &OIDC{ + p := &OIDC{ Name: name, Type: "OIDC", ClientID: clientID, @@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateGCP() (*GCP, error) { @@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &GCP{ + p := &GCP{ Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, Claims: &globalProvisionerClaims, - claimer: claimer, config: newGCPConfig(), keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - audiences: testAudiences.WithFragment("gcp/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("gcp/" + name), + }) + return p, err } func generateAWS() (*AWS, error) { @@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v2", "v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServer() (*AWS, *httptest.Server, error) { @@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { @@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } jwk, err := generateJSONWebKey() if err != nil { return nil, err } - return &Azure{ + p := &Azure{ Type: "Azure", Name: name, TenantID: tenantID, Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, - claimer: claimer, config: newAzureConfig(tenantID), oidcConfig: openIDConfiguration{ Issuer: "https://sts.windows.net/" + tenantID + "/", @@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateAzureWithServer() (*Azure, *httptest.Server, error) { diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index aa44245d..6f534c76 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -26,15 +26,14 @@ type x5cPayload struct { // signature requests. type X5C struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences - rootPool *x509.CertPool + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + ctl *Controller + rootPool *x509.CertPool } // GetID returns the provisioner unique identifier. The name and credential id @@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a X5C type. -func (p *X5C) Init(config Config) error { +func (p *X5C) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -119,14 +118,9 @@ func (p *X5C) Init(config Config) error { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -189,13 +183,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } @@ -227,31 +221,30 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, ""), - profileLimitDuration{p.claimer.DefaultTLSCertDuration(), - claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, + profileLimitDuration{ + p.ctl.Claimer.DefaultTLSCertDuration(), + claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter, + }, // validators commonNameValidator(claims.Subject), defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") } @@ -314,11 +307,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, + &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 2959f8c6..84e29b48 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,6 +2,7 @@ package provisioner import ( "context" + "crypto/x509" "net/http" "testing" "time" @@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) { }, "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, + p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, @@ -141,7 +142,7 @@ M46l92gdOozT } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) + assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } @@ -468,14 +469,14 @@ func TestX5C_AuthorizeSign(t *testing.T) { switch v := o.(type) { case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeX5C)) + assert.Equals(t, v.Type, TypeX5C) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: - assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) + claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) case commonNameValidator: @@ -484,8 +485,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tc.sans) case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } @@ -551,6 +552,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *X5C code int @@ -563,12 +565,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), + err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -582,7 +584,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { + if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -618,7 +623,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, @@ -774,10 +779,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { case sshCertDefaultsModifier: assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) diff --git a/authority/provisioners.go b/authority/provisioners.go index 8dc27c6a..a6ac5aa8 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, }, - GetIdentityFunc: a.getIdentityFunc, + GetIdentityFunc: a.getIdentityFunc, + AuthorizeRenewFunc: a.authorizeRenewFunc, + AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc, }, nil } @@ -435,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { } pc := &provisioner.Claims{ - DisableRenewal: &c.DisableRenewal, + DisableRenewal: &c.DisableRenewal, + AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry, } var err error @@ -473,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { } disableRenewal := config.DefaultDisableRenewal + allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry + if c.DisableRenewal != nil { disableRenewal = *c.DisableRenewal } + if c.AllowRenewAfterExpiry != nil { + allowRenewAfterExpiry = *c.AllowRenewAfterExpiry + } lc := &linkedca.Claims{ - DisableRenewal: disableRenewal, + DisableRenewal: disableRenewal, + AllowRenewAfterExpiry: allowRenewAfterExpiry, } if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { diff --git a/authority/tls_test.go b/authority/tls_test.go index aeadaf0f..6ccf02ca 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -798,7 +798,20 @@ func TestAuthority_Renew(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + code: http.StatusUnauthorized, + }, nil + }, + "fail/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return errs.Unauthorized("not authorized") + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), code: http.StatusUnauthorized, }, nil }, @@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) { cert: cert, }, nil }, + "ok/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return nil + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + }, nil + }, } for name, genTestCase := range tests { @@ -856,7 +880,7 @@ func TestAuthority_Renew(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), @@ -956,7 +980,7 @@ func TestAuthority_Rekey(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -998,7 +1022,7 @@ func TestAuthority_Rekey(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1063,7 +1087,7 @@ func TestAuthority_Rekey(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 9482d657..0e16bd7d 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -408,6 +408,7 @@ func TestBootstrapClientServerRotation(t *testing.T) { server.ServeTLS(listener, "", "") }() defer server.Close() + time.Sleep(1 * time.Second) // Create bootstrap client token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") @@ -419,7 +420,6 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { - time.Sleep(1 * time.Second) // test with ca resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { diff --git a/ca/client.go b/ca/client.go index 6bc48a42..3a36fcd6 100644 --- a/ca/client.go +++ b/ca/client.go @@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool { return false } +// GetCaURL returns the configured CA url. +func (c *Client) GetCaURL() string { + return c.endpoint.String() +} + // GetRootCAs returns the RootCAs certificate pool from the configured // transport. func (c *Client) GetRootCAs() *x509.CertPool { @@ -723,6 +728,36 @@ retry: return &sign, nil } +// RenewWithToken performs the renew request to the CA with the given +// authorization token and returns the api.SignResponse struct. This method is +// generally used to renew an expired certificate. +func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) + req, err := http.NewRequest("POST", u.String(), http.NoBody) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request") + } + req.Header.Add("Authorization", "Bearer "+token) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + } + var sign api.SignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u) + } + return &sign, nil +} + // Rekey performs the rekey request to the CA and returns the api.SignResponse // struct. func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 29a4848d..a00ca1cf 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -529,6 +529,74 @@ func TestClient_Renew(t *testing.T) { } } +func TestClient_RenewWithToken(t *testing.T) { + ok := &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CertChainPEM: []api.Certificate{ + {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(rootPEM)}, + }, + } + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + err error + }{ + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("Authorization") != "Bearer token" { + api.JSONStatus(w, errs.InternalServer("force"), 500) + } else { + api.JSONStatus(w, tt.response, tt.responseCode) + } + }) + + got, err := c.RenewWithToken("token") + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.RenewWithToken() = %v, want nil", got) + } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, err.Error(), tt.err.Error()) + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, @@ -1060,3 +1128,28 @@ func TestClient_SSHBastion(t *testing.T) { }) } } + +func TestClient_GetCaURL(t *testing.T) { + tests := []struct { + name string + caURL string + want string + }{ + {"ok", "https://ca.com", "https://ca.com"}, + {"ok no schema", "ca.com", "https://ca.com"}, + {"ok with port", "https://ca.com:9000", "https://ca.com:9000"}, + {"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + if got := c.GetCaURL(); got != tt.want { + t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d40325e8..2a336f24 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -6,7 +6,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.3, diff --git a/ca/testdata/federated-ca.json b/ca/testdata/federated-ca.json index 342adfcf..0b1c6c8d 100644 --- a/ca/testdata/federated-ca.json +++ b/ca/testdata/federated-ca.json @@ -6,7 +6,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-0.json b/ca/testdata/rotate-ca-0.json index 20dd603a..aa9353ed 100644 --- a/ca/testdata/rotate-ca-0.json +++ b/ca/testdata/rotate-ca-0.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-1.json b/ca/testdata/rotate-ca-1.json index b038f694..c78ba035 100644 --- a/ca/testdata/rotate-ca-1.json +++ b/ca/testdata/rotate-ca-1.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-2.json b/ca/testdata/rotate-ca-2.json index 7ec965d0..2db1c992 100644 --- a/ca/testdata/rotate-ca-2.json +++ b/ca/testdata/rotate-ca-2.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-3.json b/ca/testdata/rotate-ca-3.json index 968da6ba..50f4a118 100644 --- a/ca/testdata/rotate-ca-3.json +++ b/ca/testdata/rotate-ca-3.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/go.mod b/go.mod index e6696529..6033d05e 100644 --- a/go.mod +++ b/go.mod @@ -34,8 +34,8 @@ require ( github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.15.0 - go.step.sm/linkedca v0.10.0 + go.step.sm/crypto v0.15.3 + go.step.sm/linkedca v0.11.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect diff --git a/go.sum b/go.sum index 123df6e4..c7a18aad 100644 --- a/go.sum +++ b/go.sum @@ -683,10 +683,10 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= -go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.10.0 h1:+bqymMRulHYkVde4l16FnqFVskoS6HCWJN5Z5cxAqF8= -go.step.sm/linkedca v0.10.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= +go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= +go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= +go.step.sm/linkedca v0.11.0 h1:jkG5XDQz9VSz2PH+cGjDvJTwiIziN0SWExTnicWpb8o= +go.step.sm/linkedca v0.11.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=