forked from TrueCloudLab/certificates
Add support for renew when using stepcas
It supports renewing X.509 certificates when an RA is configured with stepcas. This will only work when the renewal uses a token, and it won't work with mTLS. The audience cannot be properly verified when an RA is used, to avoid this we will get from the database if an RA was used to issue the initial certificate and we will accept the renew token. Fixes #1021 for stepcas
This commit is contained in:
parent
068a2dae8e
commit
c7f226bcec
16 changed files with 487 additions and 39 deletions
|
@ -40,6 +40,7 @@ type Authority interface {
|
|||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
|
|
|
@ -192,6 +192,7 @@ type mockAuthority struct {
|
|||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
|
@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.renewContext != nil {
|
||||
return m.renewContext(ctx, oldcert, pk)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.rekey != nil {
|
||||
return m.rekey(oldcert, pk)
|
||||
|
|
24
api/renew.go
24
api/renew.go
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
|
@ -17,14 +18,22 @@ const (
|
|||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := getPeerCertificate(r)
|
||||
ctx := r.Context()
|
||||
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Renew(cert)
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
|||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
return r.TLS.PeerCertificates[0], "", nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
ctx := r.Context()
|
||||
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
return peer, parts[1], err
|
||||
}
|
||||
}
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
return nil, "", errs.BadRequest("missing client certificate")
|
||||
}
|
||||
|
|
|
@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
// extra extension cannot be found, authorize the renewal by default.
|
||||
//
|
||||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
serial := cert.SerialNumber.String()
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||
|
||||
|
@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
}
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
||||
if err := p.AuthorizeRenew(ctx, cert); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
return nil
|
||||
|
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
|||
}
|
||||
|
||||
audiences := a.config.GetAudiences().Renew
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) {
|
||||
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||
}
|
||||
|
||||
|
|
|
@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
err := tc.auth.authorizeRenew(tc.cert)
|
||||
err := tc.auth.authorizeRenew(context.Background(), tc.cert)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var sc render.StatusCodedError
|
||||
|
@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
})
|
||||
return nil
|
||||
}))
|
||||
a4 := testAuthority(t)
|
||||
a4.db = &db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
|
||||
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://ra.example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
|
@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
{"ok", a1, args{ctx, t1}, c1, false},
|
||||
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
|
||||
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
|
||||
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||
|
|
|
@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
|
|||
}
|
||||
}
|
||||
|
||||
// wrapRAProvisioner wraps the given provisioner with RA information.
|
||||
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
|
||||
return &wrappedProvisioner{
|
||||
Interface: p,
|
||||
raInfo: raInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// isRAProvisioner returns if the given provisioner is an RA provisioner.
|
||||
func isRAProvisioner(p provisioner.Interface) bool {
|
||||
if rap, ok := p.(raProvisioner); ok {
|
||||
return rap.RAInfo() != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||
type wrappedProvisioner struct {
|
||||
provisioner.Interface
|
||||
|
@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
|
|||
}
|
||||
if err == nil && data != nil && data.Provisioner != nil {
|
||||
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
||||
if data.RaInfo != nil {
|
||||
return wrapRAProvisioner(p, data.RaInfo), nil
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapRAProvisioner(t *testing.T) {
|
||||
type args struct {
|
||||
p provisioner.Interface
|
||||
raInfo *provisioner.RAInfo
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *wrappedProvisioner
|
||||
}{
|
||||
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isRAProvisioner(t *testing.T) {
|
||||
type args struct {
|
||||
p provisioner.Interface
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{"true", args{&wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}}, true},
|
||||
{"nil ra", args{&wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
}}, false},
|
||||
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isRAProvisioner(tt.args.p); got != tt.want {
|
||||
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,19 @@ import (
|
|||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
type tokenKey struct{}
|
||||
|
||||
// NewTokenContext adds the given token to the context.
|
||||
func NewTokenContext(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, tokenKey{}, token)
|
||||
}
|
||||
|
||||
// TokenFromContext returns the token from the given context.
|
||||
func TokenFromContext(ctx context.Context) (token string, ok bool) {
|
||||
token, ok = ctx.Value(tokenKey{}).(string)
|
||||
return
|
||||
}
|
||||
|
||||
// GetTLSOptions returns the tls options configured.
|
||||
func (a *Authority) GetTLSOptions() *config.TLSOptions {
|
||||
return a.config.TLS
|
||||
|
@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
|
|||
return a.policyEngine.AreSANsAllowed(sans)
|
||||
}
|
||||
|
||||
// Renew creates a new Certificate identical to the old certificate, except
|
||||
// with a validity window that begins 'now'.
|
||||
// Renew creates a new Certificate identical to the old certificate, except with
|
||||
// a validity window that begins 'now'.
|
||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
return a.Rekey(oldCert, nil)
|
||||
return a.RenewContext(context.Background(), oldCert, nil)
|
||||
}
|
||||
|
||||
// Rekey is used for rekeying and renewing based on the public key.
|
||||
// If the public key is 'nil' then it's assumed that the cert should be renewed
|
||||
// using the existing public key. If the public key is not 'nil' then it's
|
||||
// assumed that the cert should be rekeyed.
|
||||
// Rekey is used for rekeying and renewing based on the public key. If the
|
||||
// public key is 'nil' then it's assumed that the cert should be renewed using
|
||||
// the existing public key. If the public key is not 'nil' then it's assumed
|
||||
// that the cert should be rekeyed.
|
||||
//
|
||||
// For both Rekey and Renew all other attributes of the new certificate should
|
||||
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
|
||||
// have changed), 'SubjectKeyId' (different in case of rekey), and
|
||||
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
|
||||
// equal to the old one, but starting 'now').
|
||||
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
return a.RenewContext(context.Background(), oldCert, pk)
|
||||
}
|
||||
|
||||
// RenewContext creates a new certificate identical to the old one, but it can
|
||||
// optionally replace the public key with the given one. When running on RA
|
||||
// mode, it can only renew a certificate using a renew token instead.
|
||||
//
|
||||
// For both rekey and renew operations, all other attributes of the new
|
||||
// certificate should match the old certificate. The exceptions are
|
||||
// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case
|
||||
// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new
|
||||
// certificate should be equal to the old one, but starting 'now').
|
||||
func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
isRekey := (pk != nil)
|
||||
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
||||
opts := []errs.Option{
|
||||
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||
}
|
||||
|
||||
// Check step provisioner extensions
|
||||
if err := a.authorizeRenew(oldCert); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
||||
if err := a.authorizeRenew(ctx, oldCert); err != nil {
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
// Durations
|
||||
|
@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
|||
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
|
||||
var ee *errs.Error
|
||||
if errors.As(err, &ee) {
|
||||
return nil, errs.ApplyOptions(ee, opts...)
|
||||
return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...)
|
||||
}
|
||||
return nil, errs.InternalServerErr(err,
|
||||
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||
|
@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
|||
)
|
||||
}
|
||||
|
||||
// The token can optionally be in the context. If the CA is running in RA
|
||||
// mode, this can be used to renew a certificate.
|
||||
token, _ := TokenFromContext(ctx)
|
||||
|
||||
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
|
||||
Template: newCert,
|
||||
Lifetime: lifetime,
|
||||
Backdate: backdate,
|
||||
Token: token,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
||||
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
|
||||
if !errors.Is(err, db.ErrNotImplemented) {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: _a,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: error creating certificate"),
|
||||
err: errors.New("error creating certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}, nil
|
||||
},
|
||||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: aa,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
|
||||
err: errors.New("authority.authorizeRenew: not authorized"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: _a,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: error creating certificate"),
|
||||
err: errors.New("error creating certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}, nil
|
||||
},
|
||||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
|
|
@ -81,6 +81,7 @@ type RenewCertificateRequest struct {
|
|||
CSR *x509.CertificateRequest
|
||||
Lifetime time.Duration
|
||||
Backdate time.Duration
|
||||
Token string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
|
|
|
@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string {
|
|||
func (e NotImplementedError) StatusCode() int {
|
||||
return http.StatusNotImplemented
|
||||
}
|
||||
|
||||
// ValidationError is the type of error returned if request is not properly
|
||||
// validated.
|
||||
type ValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// NotImplementedError implements the error interface.
|
||||
func (e ValidationError) Error() string {
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
return "bad request"
|
||||
}
|
||||
|
||||
// StatusCode implements the StatusCoder interface and returns the HTTP 400
|
||||
// error.
|
||||
func (e ValidationError) StatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
|
|
@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationError_Error(t *testing.T) {
|
||||
type fields struct {
|
||||
Message string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"default", fields{""}, "bad request"},
|
||||
{"with message", fields{"token is empty"}, "token is empty"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ValidationError{
|
||||
Message: tt.fields.Message,
|
||||
}
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationError_StatusCode(t *testing.T) {
|
||||
type fields struct {
|
||||
Message string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want int
|
||||
}{
|
||||
{"default", fields{""}, 400},
|
||||
{"with message", fields{"token is empty"}, 400},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ValidationError{
|
||||
Message: tt.fields.Message,
|
||||
}
|
||||
if got := e.StatusCode(); got != tt.want {
|
||||
t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
|
|||
// RenewCertificate will always return a non-implemented error as mTLS renewals
|
||||
// are not supported yet.
|
||||
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
|
||||
return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"}
|
||||
if req.Token == "" {
|
||||
return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"}
|
||||
}
|
||||
|
||||
resp, err := s.client.RenewWithToken(req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chain []*x509.Certificate
|
||||
cert := resp.CertChainPEM[0].Certificate
|
||||
for _, c := range resp.CertChainPEM[1:] {
|
||||
chain = append(chain, c.Certificate)
|
||||
}
|
||||
|
||||
return &apiv1.RenewCertificateResponse{
|
||||
Certificate: cert,
|
||||
CertificateChain: chain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeCertificate revokes a certificate.
|
||||
|
|
|
@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) {
|
|||
writeJSON(w, api.SignResponse{
|
||||
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||
})
|
||||
case r.RequestURI == "/renew":
|
||||
if r.Header.Get("Authorization") == "Bearer fail" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `{"error":"fail","message":"fail"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
writeJSON(w, api.SignResponse{
|
||||
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||
})
|
||||
case r.RequestURI == "/revoke":
|
||||
var msg api.RevokeRequest
|
||||
parseJSON(r, &msg)
|
||||
|
@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) {
|
|||
|
||||
func TestStepCAS_RenewCertificate(t *testing.T) {
|
||||
caURL, client := testCAHelper(t)
|
||||
x5c := testX5CIssuer(t, caURL, "")
|
||||
jwk := testJWKIssuer(t, caURL, "")
|
||||
|
||||
tokenIssuer := testX5CIssuer(t, caURL, "")
|
||||
token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
iss stepIssuer
|
||||
client *ca.Client
|
||||
|
@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
|||
want *apiv1.RenewCertificateResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
CSR: testCR,
|
||||
{"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
Token: token,
|
||||
}}, &apiv1.RenewCertificateResponse{
|
||||
Certificate: testCrt,
|
||||
CertificateChain: []*x509.Certificate{testIssCrt},
|
||||
}, false},
|
||||
{"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
}}, nil, true},
|
||||
{"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
CSR: testCR,
|
||||
{"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
Token: "fail",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
|||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want)
|
||||
t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate))
|
||||
t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain))
|
||||
|
||||
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
40
db/db.go
40
db/db.go
|
@ -28,8 +28,9 @@ var (
|
|||
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||
)
|
||||
|
||||
var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
||||
// is this acceptable? probably not....
|
||||
// TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
||||
// is this acceptable? probably not....
|
||||
var crlKey = []byte("crl")
|
||||
|
||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||
// been previously set.
|
||||
|
@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
|||
// CertificateData is the JSON representation of the data stored in
|
||||
// x509_certs_data table.
|
||||
type CertificateData struct {
|
||||
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||
RaInfo *provisioner.RAInfo `json:"ra,omitempty"`
|
||||
}
|
||||
|
||||
// ProvisionerData is the JSON representation of the provisioner stored in the
|
||||
|
@ -334,6 +336,10 @@ type ProvisionerData struct {
|
|||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type raProvisioner interface {
|
||||
RAInfo() *provisioner.RAInfo
|
||||
}
|
||||
|
||||
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
||||
// authorized the certificate.
|
||||
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
||||
|
@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
|||
Name: p.GetName(),
|
||||
Type: p.GetType().String(),
|
||||
}
|
||||
if rap, ok := p.(raProvisioner); ok {
|
||||
data.RaInfo = rap.RAInfo()
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
|||
return nil
|
||||
}
|
||||
|
||||
// StoreRenewedCertificate stores the leaf certificate and the provisioner that
|
||||
// authorized the old certificate if available.
|
||||
func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error {
|
||||
var certificateData []byte
|
||||
if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil {
|
||||
if b, err := json.Marshal(data); err == nil {
|
||||
certificateData = b
|
||||
}
|
||||
}
|
||||
|
||||
leaf := chain[0]
|
||||
serialNumber := []byte(leaf.SerialNumber.String())
|
||||
|
||||
// Add certificate and certificate data in one transaction.
|
||||
tx := new(database.Tx)
|
||||
tx.Set(certsTable, serialNumber, leaf.Raw)
|
||||
if certificateData != nil {
|
||||
tx.Set(certsDataTable, serialNumber, certificateData)
|
||||
}
|
||||
if err := db.Update(tx); err != nil {
|
||||
return errors.Wrap(err, "database Update error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseToken returns true if we were able to successfully store the token for
|
||||
// for the first time, false otherwise.
|
||||
func (db *DB) UseToken(id, tok string) (bool, error) {
|
||||
|
|
142
db/db_test.go
142
db/db_test.go
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"math/big"
|
||||
|
@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||
type wrappedProvisioner struct {
|
||||
provisioner.Interface
|
||||
raInfo *provisioner.RAInfo
|
||||
}
|
||||
|
||||
func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo {
|
||||
return p.raInfo
|
||||
}
|
||||
|
||||
func TestDB_StoreCertificateChain(t *testing.T) {
|
||||
p := &provisioner.JWK{
|
||||
ID: "some-id",
|
||||
Name: "admin",
|
||||
Type: "JWK",
|
||||
}
|
||||
rap := &wrappedProvisioner{
|
||||
Interface: p,
|
||||
raInfo: &provisioner.RAInfo{
|
||||
ProvisionerID: "ra-id",
|
||||
ProvisionerType: "JWK",
|
||||
ProvisionerName: "ra",
|
||||
},
|
||||
}
|
||||
chain := []*x509.Certificate{
|
||||
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
|
||||
}
|
||||
|
@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
}, true}, args{p, chain}, false},
|
||||
{"ok ra provisioner", fields{&MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
t.Fatal("unexpected number of operations")
|
||||
}
|
||||
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
|
||||
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
|
||||
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
|
||||
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
|
||||
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
|
||||
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value)
|
||||
assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value))
|
||||
return nil
|
||||
},
|
||||
}, true}, args{rap, chain}, false},
|
||||
{"ok no provisioner", fields{&MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
|
@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_StoreRenewedCertificate(t *testing.T) {
|
||||
oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
chain := []*x509.Certificate{
|
||||
&x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")},
|
||||
&x509.Certificate{SerialNumber: big.NewInt(0)},
|
||||
}
|
||||
|
||||
testErr := errors.New("test error")
|
||||
certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`)
|
||||
matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool {
|
||||
return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
DB nosql.DB
|
||||
isUp bool
|
||||
}
|
||||
type args struct {
|
||||
oldCert *x509.Certificate
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) {
|
||||
return certsData, nil
|
||||
}
|
||||
t.Error("ok failed: unexpected get")
|
||||
return nil, testErr
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0, op1 := tx.Operations[0], tx.Operations[1]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
if !matchOperation(op1, certsDataTable, []byte("2"), certsData) {
|
||||
t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"ok no data", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 1 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0 := tx.Operations[0]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"ok fail marshal", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return []byte(`{"bad":"json"`), nil
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 1 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0 := tx.Operations[0]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"fail", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return certsData, nil
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
return testErr
|
||||
},
|
||||
}, true}, args{oldCert, chain}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := &DB{
|
||||
DB: tt.fields.DB,
|
||||
isUp: tt.fields.isUp,
|
||||
}
|
||||
if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue