forked from TrueCloudLab/certificates
Merge pull request #1156 from smallstep/ra-renew
Add support for renew when using stepcas
This commit is contained in:
commit
e8726d24fa
16 changed files with 487 additions and 39 deletions
|
@ -40,6 +40,7 @@ type Authority interface {
|
|||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
|
|
|
@ -192,6 +192,7 @@ type mockAuthority struct {
|
|||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
|
@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.renewContext != nil {
|
||||
return m.renewContext(ctx, oldcert, pk)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.rekey != nil {
|
||||
return m.rekey(oldcert, pk)
|
||||
|
|
24
api/renew.go
24
api/renew.go
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
|
@ -17,14 +18,22 @@ const (
|
|||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := getPeerCertificate(r)
|
||||
ctx := r.Context()
|
||||
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Renew(cert)
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
|||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
return r.TLS.PeerCertificates[0], "", nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
ctx := r.Context()
|
||||
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
return peer, parts[1], err
|
||||
}
|
||||
}
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
return nil, "", errs.BadRequest("missing client certificate")
|
||||
}
|
||||
|
|
|
@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
// extra extension cannot be found, authorize the renewal by default.
|
||||
//
|
||||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
serial := cert.SerialNumber.String()
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||
|
||||
|
@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
}
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
||||
if err := p.AuthorizeRenew(ctx, cert); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
return nil
|
||||
|
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
|||
}
|
||||
|
||||
audiences := a.config.GetAudiences().Renew
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) {
|
||||
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||
}
|
||||
|
||||
|
|
|
@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
err := tc.auth.authorizeRenew(tc.cert)
|
||||
err := tc.auth.authorizeRenew(context.Background(), tc.cert)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var sc render.StatusCodedError
|
||||
|
@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
})
|
||||
return nil
|
||||
}))
|
||||
a4 := testAuthority(t)
|
||||
a4.db = &db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
|
||||
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://ra.example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
|
@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
{"ok", a1, args{ctx, t1}, c1, false},
|
||||
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
|
||||
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
|
||||
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||
|
|
|
@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
|
|||
}
|
||||
}
|
||||
|
||||
// wrapRAProvisioner wraps the given provisioner with RA information.
|
||||
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
|
||||
return &wrappedProvisioner{
|
||||
Interface: p,
|
||||
raInfo: raInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// isRAProvisioner returns if the given provisioner is an RA provisioner.
|
||||
func isRAProvisioner(p provisioner.Interface) bool {
|
||||
if rap, ok := p.(raProvisioner); ok {
|
||||
return rap.RAInfo() != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||
type wrappedProvisioner struct {
|
||||
provisioner.Interface
|
||||
|
@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
|
|||
}
|
||||
if err == nil && data != nil && data.Provisioner != nil {
|
||||
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
||||
if data.RaInfo != nil {
|
||||
return wrapRAProvisioner(p, data.RaInfo), nil
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapRAProvisioner(t *testing.T) {
|
||||
type args struct {
|
||||
p provisioner.Interface
|
||||
raInfo *provisioner.RAInfo
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *wrappedProvisioner
|
||||
}{
|
||||
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isRAProvisioner(t *testing.T) {
|
||||
type args struct {
|
||||
p provisioner.Interface
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{"true", args{&wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||
}}, true},
|
||||
{"nil ra", args{&wrappedProvisioner{
|
||||
Interface: &provisioner.JWK{Name: "jwt"},
|
||||
}}, false},
|
||||
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isRAProvisioner(tt.args.p); got != tt.want {
|
||||
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,19 @@ import (
|
|||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
type tokenKey struct{}
|
||||
|
||||
// NewTokenContext adds the given token to the context.
|
||||
func NewTokenContext(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, tokenKey{}, token)
|
||||
}
|
||||
|
||||
// TokenFromContext returns the token from the given context.
|
||||
func TokenFromContext(ctx context.Context) (token string, ok bool) {
|
||||
token, ok = ctx.Value(tokenKey{}).(string)
|
||||
return
|
||||
}
|
||||
|
||||
// GetTLSOptions returns the tls options configured.
|
||||
func (a *Authority) GetTLSOptions() *config.TLSOptions {
|
||||
return a.config.TLS
|
||||
|
@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
|
|||
return a.policyEngine.AreSANsAllowed(sans)
|
||||
}
|
||||
|
||||
// Renew creates a new Certificate identical to the old certificate, except
|
||||
// with a validity window that begins 'now'.
|
||||
// Renew creates a new Certificate identical to the old certificate, except with
|
||||
// a validity window that begins 'now'.
|
||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
return a.Rekey(oldCert, nil)
|
||||
return a.RenewContext(context.Background(), oldCert, nil)
|
||||
}
|
||||
|
||||
// Rekey is used for rekeying and renewing based on the public key.
|
||||
// If the public key is 'nil' then it's assumed that the cert should be renewed
|
||||
// using the existing public key. If the public key is not 'nil' then it's
|
||||
// assumed that the cert should be rekeyed.
|
||||
// Rekey is used for rekeying and renewing based on the public key. If the
|
||||
// public key is 'nil' then it's assumed that the cert should be renewed using
|
||||
// the existing public key. If the public key is not 'nil' then it's assumed
|
||||
// that the cert should be rekeyed.
|
||||
//
|
||||
// For both Rekey and Renew all other attributes of the new certificate should
|
||||
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
|
||||
// have changed), 'SubjectKeyId' (different in case of rekey), and
|
||||
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
|
||||
// equal to the old one, but starting 'now').
|
||||
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
return a.RenewContext(context.Background(), oldCert, pk)
|
||||
}
|
||||
|
||||
// RenewContext creates a new certificate identical to the old one, but it can
|
||||
// optionally replace the public key with the given one. When running on RA
|
||||
// mode, it can only renew a certificate using a renew token instead.
|
||||
//
|
||||
// For both rekey and renew operations, all other attributes of the new
|
||||
// certificate should match the old certificate. The exceptions are
|
||||
// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case
|
||||
// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new
|
||||
// certificate should be equal to the old one, but starting 'now').
|
||||
func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
isRekey := (pk != nil)
|
||||
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
||||
opts := []errs.Option{
|
||||
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||
}
|
||||
|
||||
// Check step provisioner extensions
|
||||
if err := a.authorizeRenew(oldCert); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
||||
if err := a.authorizeRenew(ctx, oldCert); err != nil {
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
// Durations
|
||||
|
@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
|||
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
|
||||
var ee *errs.Error
|
||||
if errors.As(err, &ee) {
|
||||
return nil, errs.ApplyOptions(ee, opts...)
|
||||
return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...)
|
||||
}
|
||||
return nil, errs.InternalServerErr(err,
|
||||
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||
|
@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
|||
)
|
||||
}
|
||||
|
||||
// The token can optionally be in the context. If the CA is running in RA
|
||||
// mode, this can be used to renew a certificate.
|
||||
token, _ := TokenFromContext(ctx)
|
||||
|
||||
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
|
||||
Template: newCert,
|
||||
Lifetime: lifetime,
|
||||
Backdate: backdate,
|
||||
Token: token,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
||||
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
|
||||
if !errors.Is(err, db.ErrNotImplemented) {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
|
||||
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: _a,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: error creating certificate"),
|
||||
err: errors.New("error creating certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}, nil
|
||||
},
|
||||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: aa,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
|
||||
err: errors.New("authority.authorizeRenew: not authorized"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) {
|
|||
return &renewTest{
|
||||
auth: _a,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: error creating certificate"),
|
||||
err: errors.New("error creating certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}, nil
|
||||
},
|
||||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
|
|
@ -81,6 +81,7 @@ type RenewCertificateRequest struct {
|
|||
CSR *x509.CertificateRequest
|
||||
Lifetime time.Duration
|
||||
Backdate time.Duration
|
||||
Token string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
|
|
|
@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string {
|
|||
func (e NotImplementedError) StatusCode() int {
|
||||
return http.StatusNotImplemented
|
||||
}
|
||||
|
||||
// ValidationError is the type of error returned if request is not properly
|
||||
// validated.
|
||||
type ValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// NotImplementedError implements the error interface.
|
||||
func (e ValidationError) Error() string {
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
return "bad request"
|
||||
}
|
||||
|
||||
// StatusCode implements the StatusCoder interface and returns the HTTP 400
|
||||
// error.
|
||||
func (e ValidationError) StatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
|
|
@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationError_Error(t *testing.T) {
|
||||
type fields struct {
|
||||
Message string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"default", fields{""}, "bad request"},
|
||||
{"with message", fields{"token is empty"}, "token is empty"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ValidationError{
|
||||
Message: tt.fields.Message,
|
||||
}
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationError_StatusCode(t *testing.T) {
|
||||
type fields struct {
|
||||
Message string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want int
|
||||
}{
|
||||
{"default", fields{""}, 400},
|
||||
{"with message", fields{"token is empty"}, 400},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := ValidationError{
|
||||
Message: tt.fields.Message,
|
||||
}
|
||||
if got := e.StatusCode(); got != tt.want {
|
||||
t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
|
|||
// RenewCertificate will always return a non-implemented error as mTLS renewals
|
||||
// are not supported yet.
|
||||
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
|
||||
return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"}
|
||||
if req.Token == "" {
|
||||
return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"}
|
||||
}
|
||||
|
||||
resp, err := s.client.RenewWithToken(req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chain []*x509.Certificate
|
||||
cert := resp.CertChainPEM[0].Certificate
|
||||
for _, c := range resp.CertChainPEM[1:] {
|
||||
chain = append(chain, c.Certificate)
|
||||
}
|
||||
|
||||
return &apiv1.RenewCertificateResponse{
|
||||
Certificate: cert,
|
||||
CertificateChain: chain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeCertificate revokes a certificate.
|
||||
|
|
|
@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) {
|
|||
writeJSON(w, api.SignResponse{
|
||||
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||
})
|
||||
case r.RequestURI == "/renew":
|
||||
if r.Header.Get("Authorization") == "Bearer fail" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `{"error":"fail","message":"fail"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
writeJSON(w, api.SignResponse{
|
||||
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||
})
|
||||
case r.RequestURI == "/revoke":
|
||||
var msg api.RevokeRequest
|
||||
parseJSON(r, &msg)
|
||||
|
@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) {
|
|||
|
||||
func TestStepCAS_RenewCertificate(t *testing.T) {
|
||||
caURL, client := testCAHelper(t)
|
||||
x5c := testX5CIssuer(t, caURL, "")
|
||||
jwk := testJWKIssuer(t, caURL, "")
|
||||
|
||||
tokenIssuer := testX5CIssuer(t, caURL, "")
|
||||
token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
iss stepIssuer
|
||||
client *ca.Client
|
||||
|
@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
|||
want *apiv1.RenewCertificateResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
CSR: testCR,
|
||||
{"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
Token: token,
|
||||
}}, &apiv1.RenewCertificateResponse{
|
||||
Certificate: testCrt,
|
||||
CertificateChain: []*x509.Certificate{testIssCrt},
|
||||
}, false},
|
||||
{"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
}}, nil, true},
|
||||
{"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
CSR: testCR,
|
||||
{"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||
Template: &x509.Certificate{},
|
||||
Backdate: time.Minute,
|
||||
Lifetime: time.Hour,
|
||||
Token: "fail",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
|||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want)
|
||||
t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate))
|
||||
t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain))
|
||||
|
||||
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
40
db/db.go
40
db/db.go
|
@ -28,8 +28,9 @@ var (
|
|||
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||
)
|
||||
|
||||
var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
||||
// is this acceptable? probably not....
|
||||
// TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
||||
// is this acceptable? probably not....
|
||||
var crlKey = []byte("crl")
|
||||
|
||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||
// been previously set.
|
||||
|
@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
|||
// CertificateData is the JSON representation of the data stored in
|
||||
// x509_certs_data table.
|
||||
type CertificateData struct {
|
||||
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||
RaInfo *provisioner.RAInfo `json:"ra,omitempty"`
|
||||
}
|
||||
|
||||
// ProvisionerData is the JSON representation of the provisioner stored in the
|
||||
|
@ -334,6 +336,10 @@ type ProvisionerData struct {
|
|||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type raProvisioner interface {
|
||||
RAInfo() *provisioner.RAInfo
|
||||
}
|
||||
|
||||
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
||||
// authorized the certificate.
|
||||
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
||||
|
@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
|||
Name: p.GetName(),
|
||||
Type: p.GetType().String(),
|
||||
}
|
||||
if rap, ok := p.(raProvisioner); ok {
|
||||
data.RaInfo = rap.RAInfo()
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
|||
return nil
|
||||
}
|
||||
|
||||
// StoreRenewedCertificate stores the leaf certificate and the provisioner that
|
||||
// authorized the old certificate if available.
|
||||
func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error {
|
||||
var certificateData []byte
|
||||
if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil {
|
||||
if b, err := json.Marshal(data); err == nil {
|
||||
certificateData = b
|
||||
}
|
||||
}
|
||||
|
||||
leaf := chain[0]
|
||||
serialNumber := []byte(leaf.SerialNumber.String())
|
||||
|
||||
// Add certificate and certificate data in one transaction.
|
||||
tx := new(database.Tx)
|
||||
tx.Set(certsTable, serialNumber, leaf.Raw)
|
||||
if certificateData != nil {
|
||||
tx.Set(certsDataTable, serialNumber, certificateData)
|
||||
}
|
||||
if err := db.Update(tx); err != nil {
|
||||
return errors.Wrap(err, "database Update error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseToken returns true if we were able to successfully store the token for
|
||||
// for the first time, false otherwise.
|
||||
func (db *DB) UseToken(id, tok string) (bool, error) {
|
||||
|
|
142
db/db_test.go
142
db/db_test.go
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"math/big"
|
||||
|
@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||
type wrappedProvisioner struct {
|
||||
provisioner.Interface
|
||||
raInfo *provisioner.RAInfo
|
||||
}
|
||||
|
||||
func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo {
|
||||
return p.raInfo
|
||||
}
|
||||
|
||||
func TestDB_StoreCertificateChain(t *testing.T) {
|
||||
p := &provisioner.JWK{
|
||||
ID: "some-id",
|
||||
Name: "admin",
|
||||
Type: "JWK",
|
||||
}
|
||||
rap := &wrappedProvisioner{
|
||||
Interface: p,
|
||||
raInfo: &provisioner.RAInfo{
|
||||
ProvisionerID: "ra-id",
|
||||
ProvisionerType: "JWK",
|
||||
ProvisionerName: "ra",
|
||||
},
|
||||
}
|
||||
chain := []*x509.Certificate{
|
||||
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
|
||||
}
|
||||
|
@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
}, true}, args{p, chain}, false},
|
||||
{"ok ra provisioner", fields{&MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
t.Fatal("unexpected number of operations")
|
||||
}
|
||||
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
|
||||
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
|
||||
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
|
||||
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
|
||||
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
|
||||
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value)
|
||||
assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value))
|
||||
return nil
|
||||
},
|
||||
}, true}, args{rap, chain}, false},
|
||||
{"ok no provisioner", fields{&MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
|
@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_StoreRenewedCertificate(t *testing.T) {
|
||||
oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
chain := []*x509.Certificate{
|
||||
&x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")},
|
||||
&x509.Certificate{SerialNumber: big.NewInt(0)},
|
||||
}
|
||||
|
||||
testErr := errors.New("test error")
|
||||
certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`)
|
||||
matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool {
|
||||
return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
DB nosql.DB
|
||||
isUp bool
|
||||
}
|
||||
type args struct {
|
||||
oldCert *x509.Certificate
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) {
|
||||
return certsData, nil
|
||||
}
|
||||
t.Error("ok failed: unexpected get")
|
||||
return nil, testErr
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 2 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0, op1 := tx.Operations[0], tx.Operations[1]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
if !matchOperation(op1, certsDataTable, []byte("2"), certsData) {
|
||||
t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"ok no data", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 1 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0 := tx.Operations[0]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"ok fail marshal", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return []byte(`{"bad":"json"`), nil
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
if len(tx.Operations) != 1 {
|
||||
t.Error("ok failed: unexpected number of operations")
|
||||
return testErr
|
||||
}
|
||||
op0 := tx.Operations[0]
|
||||
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||
return testErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, true}, args{oldCert, chain}, false},
|
||||
{"fail", fields{&MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return certsData, nil
|
||||
},
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
return testErr
|
||||
},
|
||||
}, true}, args{oldCert, chain}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := &DB{
|
||||
DB: tt.fields.DB,
|
||||
isUp: tt.fields.isUp,
|
||||
}
|
||||
if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue