forked from TrueCloudLab/certificates
Merge pull request #1156 from smallstep/ra-renew
Add support for renew when using stepcas
This commit is contained in:
commit
e8726d24fa
16 changed files with 487 additions and 39 deletions
|
@ -40,6 +40,7 @@ type Authority interface {
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
|
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||||
|
|
|
@ -192,6 +192,7 @@ type mockAuthority struct {
|
||||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||||
|
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||||
|
@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
|
||||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||||
|
if m.renewContext != nil {
|
||||||
|
return m.renewContext(ctx, oldcert, pk)
|
||||||
|
}
|
||||||
|
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||||
if m.rekey != nil {
|
if m.rekey != nil {
|
||||||
return m.rekey(oldcert, pk)
|
return m.rekey(oldcert, pk)
|
||||||
|
|
24
api/renew.go
24
api/renew.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api/render"
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,14 +18,22 @@ const (
|
||||||
// Renew uses the information of certificate in the TLS connection to create a
|
// Renew uses the information of certificate in the TLS connection to create a
|
||||||
// new one.
|
// new one.
|
||||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
cert, err := getPeerCertificate(r)
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// Get the leaf certificate from the peer or the token.
|
||||||
|
cert, token, err := getPeerCertificate(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
a := mustAuthority(r.Context())
|
// The token can be used by RAs to renew a certificate.
|
||||||
certChain, err := a.Renew(cert)
|
if token != "" {
|
||||||
|
ctx = authority.NewTokenContext(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||||
return
|
return
|
||||||
|
@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
return r.TLS.PeerCertificates[0], nil
|
return r.TLS.PeerCertificates[0], "", nil
|
||||||
}
|
}
|
||||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||||
|
return peer, parts[1], err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errs.BadRequest("missing client certificate")
|
return nil, "", errs.BadRequest("missing client certificate")
|
||||||
}
|
}
|
||||||
|
|
|
@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
// extra extension cannot be found, authorize the renewal by default.
|
// extra extension cannot be found, authorize the renewal by default.
|
||||||
//
|
//
|
||||||
// TODO(mariano): should we authorize by default?
|
// TODO(mariano): should we authorize by default?
|
||||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
serial := cert.SerialNumber.String()
|
serial := cert.SerialNumber.String()
|
||||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||||
|
|
||||||
|
@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
if err := p.AuthorizeRenew(ctx, cert); err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
||||||
}
|
}
|
||||||
|
|
||||||
audiences := a.config.GetAudiences().Renew
|
audiences := a.config.GetAudiences().Renew
|
||||||
if !matchesAudience(claims.Audience, audiences) {
|
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) {
|
||||||
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := genTestCase(t)
|
tc := genTestCase(t)
|
||||||
|
|
||||||
err := tc.auth.authorizeRenew(tc.cert)
|
err := tc.auth.authorizeRenew(context.Background(), tc.cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
var sc render.StatusCodedError
|
var sc render.StatusCodedError
|
||||||
|
@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}))
|
}))
|
||||||
|
a4 := testAuthority(t)
|
||||||
|
a4.db = &db.MockAuthDB{
|
||||||
|
MUseToken: func(id, tok string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||||
|
return &db.CertificateData{
|
||||||
|
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
|
||||||
|
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://ra.example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-ca-client/1.0",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||||
Audience: []string{"https://example.com/1.0/renew"},
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
Subject: "test.example.com",
|
Subject: "test.example.com",
|
||||||
|
@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
||||||
{"ok", a1, args{ctx, t1}, c1, false},
|
{"ok", a1, args{ctx, t1}, c1, false},
|
||||||
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||||
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
|
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
|
||||||
|
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
|
||||||
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||||
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||||
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||||
|
|
|
@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapRAProvisioner wraps the given provisioner with RA information.
|
||||||
|
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
|
||||||
|
return &wrappedProvisioner{
|
||||||
|
Interface: p,
|
||||||
|
raInfo: raInfo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRAProvisioner returns if the given provisioner is an RA provisioner.
|
||||||
|
func isRAProvisioner(p provisioner.Interface) bool {
|
||||||
|
if rap, ok := p.(raProvisioner); ok {
|
||||||
|
return rap.RAInfo() != nil
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||||
type wrappedProvisioner struct {
|
type wrappedProvisioner struct {
|
||||||
provisioner.Interface
|
provisioner.Interface
|
||||||
|
@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
|
||||||
}
|
}
|
||||||
if err == nil && data != nil && data.Provisioner != nil {
|
if err == nil && data != nil && data.Provisioner != nil {
|
||||||
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
||||||
|
if data.RaInfo != nil {
|
||||||
|
return wrapRAProvisioner(p, data.RaInfo), nil
|
||||||
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_wrapRAProvisioner(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
p provisioner.Interface
|
||||||
|
raInfo *provisioner.RAInfo
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *wrappedProvisioner
|
||||||
|
}{
|
||||||
|
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
|
||||||
|
Interface: &provisioner.JWK{Name: "jwt"},
|
||||||
|
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_isRAProvisioner(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
p provisioner.Interface
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"true", args{&wrappedProvisioner{
|
||||||
|
Interface: &provisioner.JWK{Name: "jwt"},
|
||||||
|
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
|
||||||
|
}}, true},
|
||||||
|
{"nil ra", args{&wrappedProvisioner{
|
||||||
|
Interface: &provisioner.JWK{Name: "jwt"},
|
||||||
|
}}, false},
|
||||||
|
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := isRAProvisioner(tt.args.p); got != tt.want {
|
||||||
|
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -34,6 +34,19 @@ import (
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type tokenKey struct{}
|
||||||
|
|
||||||
|
// NewTokenContext adds the given token to the context.
|
||||||
|
func NewTokenContext(ctx context.Context, token string) context.Context {
|
||||||
|
return context.WithValue(ctx, tokenKey{}, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenFromContext returns the token from the given context.
|
||||||
|
func TokenFromContext(ctx context.Context) (token string, ok bool) {
|
||||||
|
token, ok = ctx.Value(tokenKey{}).(string)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetTLSOptions returns the tls options configured.
|
// GetTLSOptions returns the tls options configured.
|
||||||
func (a *Authority) GetTLSOptions() *config.TLSOptions {
|
func (a *Authority) GetTLSOptions() *config.TLSOptions {
|
||||||
return a.config.TLS
|
return a.config.TLS
|
||||||
|
@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
|
||||||
return a.policyEngine.AreSANsAllowed(sans)
|
return a.policyEngine.AreSANsAllowed(sans)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renew creates a new Certificate identical to the old certificate, except
|
// Renew creates a new Certificate identical to the old certificate, except with
|
||||||
// with a validity window that begins 'now'.
|
// a validity window that begins 'now'.
|
||||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||||
return a.Rekey(oldCert, nil)
|
return a.RenewContext(context.Background(), oldCert, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rekey is used for rekeying and renewing based on the public key.
|
// Rekey is used for rekeying and renewing based on the public key. If the
|
||||||
// If the public key is 'nil' then it's assumed that the cert should be renewed
|
// public key is 'nil' then it's assumed that the cert should be renewed using
|
||||||
// using the existing public key. If the public key is not 'nil' then it's
|
// the existing public key. If the public key is not 'nil' then it's assumed
|
||||||
// assumed that the cert should be rekeyed.
|
// that the cert should be rekeyed.
|
||||||
|
//
|
||||||
// For both Rekey and Renew all other attributes of the new certificate should
|
// For both Rekey and Renew all other attributes of the new certificate should
|
||||||
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
|
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
|
||||||
// have changed), 'SubjectKeyId' (different in case of rekey), and
|
// have changed), 'SubjectKeyId' (different in case of rekey), and
|
||||||
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
|
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
|
||||||
// equal to the old one, but starting 'now').
|
// equal to the old one, but starting 'now').
|
||||||
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||||
|
return a.RenewContext(context.Background(), oldCert, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenewContext creates a new certificate identical to the old one, but it can
|
||||||
|
// optionally replace the public key with the given one. When running on RA
|
||||||
|
// mode, it can only renew a certificate using a renew token instead.
|
||||||
|
//
|
||||||
|
// For both rekey and renew operations, all other attributes of the new
|
||||||
|
// certificate should match the old certificate. The exceptions are
|
||||||
|
// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case
|
||||||
|
// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new
|
||||||
|
// certificate should be equal to the old one, but starting 'now').
|
||||||
|
func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||||
isRekey := (pk != nil)
|
isRekey := (pk != nil)
|
||||||
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
opts := []errs.Option{
|
||||||
|
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||||
|
}
|
||||||
|
|
||||||
// Check step provisioner extensions
|
// Check step provisioner extensions
|
||||||
if err := a.authorizeRenew(oldCert); err != nil {
|
if err := a.authorizeRenew(ctx, oldCert); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Durations
|
// Durations
|
||||||
|
@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
||||||
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
|
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
|
||||||
var ee *errs.Error
|
var ee *errs.Error
|
||||||
if errors.As(err, &ee) {
|
if errors.As(err, &ee) {
|
||||||
return nil, errs.ApplyOptions(ee, opts...)
|
return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...)
|
||||||
}
|
}
|
||||||
return nil, errs.InternalServerErr(err,
|
return nil, errs.InternalServerErr(err,
|
||||||
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
|
||||||
|
@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The token can optionally be in the context. If the CA is running in RA
|
||||||
|
// mode, this can be used to renew a certificate.
|
||||||
|
token, _ := TokenFromContext(ctx)
|
||||||
|
|
||||||
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
|
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
|
||||||
Template: newCert,
|
Template: newCert,
|
||||||
Lifetime: lifetime,
|
Lifetime: lifetime,
|
||||||
Backdate: backdate,
|
Backdate: backdate,
|
||||||
|
Token: token,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
|
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
||||||
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
|
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
|
||||||
if !errors.Is(err, db.ErrNotImplemented) {
|
if !errors.Is(err, db.ErrNotImplemented) {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
|
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
auth: _a,
|
auth: _a,
|
||||||
cert: cert,
|
cert: cert,
|
||||||
err: errors.New("authority.Rekey: error creating certificate"),
|
err: errors.New("error creating certificate"),
|
||||||
code: http.StatusInternalServerError,
|
code: http.StatusInternalServerError,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
"fail/unauthorized": func() (*renewTest, error) {
|
"fail/unauthorized": func() (*renewTest, error) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
cert: certNoRenew,
|
cert: certNoRenew,
|
||||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
auth: aa,
|
auth: aa,
|
||||||
cert: cert,
|
cert: cert,
|
||||||
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
|
err: errors.New("authority.authorizeRenew: not authorized"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
auth: _a,
|
auth: _a,
|
||||||
cert: cert,
|
cert: cert,
|
||||||
err: errors.New("authority.Rekey: error creating certificate"),
|
err: errors.New("error creating certificate"),
|
||||||
code: http.StatusInternalServerError,
|
code: http.StatusInternalServerError,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
"fail/unauthorized": func() (*renewTest, error) {
|
"fail/unauthorized": func() (*renewTest, error) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
cert: certNoRenew,
|
cert: certNoRenew,
|
||||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
|
|
@ -81,6 +81,7 @@ type RenewCertificateRequest struct {
|
||||||
CSR *x509.CertificateRequest
|
CSR *x509.CertificateRequest
|
||||||
Lifetime time.Duration
|
Lifetime time.Duration
|
||||||
Backdate time.Duration
|
Backdate time.Duration
|
||||||
|
Token string
|
||||||
RequestID string
|
RequestID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string {
|
||||||
func (e NotImplementedError) StatusCode() int {
|
func (e NotImplementedError) StatusCode() int {
|
||||||
return http.StatusNotImplemented
|
return http.StatusNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidationError is the type of error returned if request is not properly
|
||||||
|
// validated.
|
||||||
|
type ValidationError struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotImplementedError implements the error interface.
|
||||||
|
func (e ValidationError) Error() string {
|
||||||
|
if e.Message != "" {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
return "bad request"
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode implements the StatusCoder interface and returns the HTTP 400
|
||||||
|
// error.
|
||||||
|
func (e ValidationError) StatusCode() int {
|
||||||
|
return http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|
|
@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidationError_Error(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"default", fields{""}, "bad request"},
|
||||||
|
{"with message", fields{"token is empty"}, "token is empty"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
e := ValidationError{
|
||||||
|
Message: tt.fields.Message,
|
||||||
|
}
|
||||||
|
if got := e.Error(); got != tt.want {
|
||||||
|
t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidationError_StatusCode(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"default", fields{""}, 400},
|
||||||
|
{"with message", fields{"token is empty"}, 400},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
e := ValidationError{
|
||||||
|
Message: tt.fields.Message,
|
||||||
|
}
|
||||||
|
if got := e.StatusCode(); got != tt.want {
|
||||||
|
t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
|
||||||
// RenewCertificate will always return a non-implemented error as mTLS renewals
|
// RenewCertificate will always return a non-implemented error as mTLS renewals
|
||||||
// are not supported yet.
|
// are not supported yet.
|
||||||
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
|
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
|
||||||
return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"}
|
if req.Token == "" {
|
||||||
|
return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.client.RenewWithToken(req.Token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var chain []*x509.Certificate
|
||||||
|
cert := resp.CertChainPEM[0].Certificate
|
||||||
|
for _, c := range resp.CertChainPEM[1:] {
|
||||||
|
chain = append(chain, c.Certificate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &apiv1.RenewCertificateResponse{
|
||||||
|
Certificate: cert,
|
||||||
|
CertificateChain: chain,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeCertificate revokes a certificate.
|
// RevokeCertificate revokes a certificate.
|
||||||
|
|
|
@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) {
|
||||||
writeJSON(w, api.SignResponse{
|
writeJSON(w, api.SignResponse{
|
||||||
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||||
})
|
})
|
||||||
|
case r.RequestURI == "/renew":
|
||||||
|
if r.Header.Get("Authorization") == "Bearer fail" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, `{"error":"fail","message":"fail"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
writeJSON(w, api.SignResponse{
|
||||||
|
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
|
||||||
|
})
|
||||||
case r.RequestURI == "/revoke":
|
case r.RequestURI == "/revoke":
|
||||||
var msg api.RevokeRequest
|
var msg api.RevokeRequest
|
||||||
parseJSON(r, &msg)
|
parseJSON(r, &msg)
|
||||||
|
@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) {
|
||||||
|
|
||||||
func TestStepCAS_RenewCertificate(t *testing.T) {
|
func TestStepCAS_RenewCertificate(t *testing.T) {
|
||||||
caURL, client := testCAHelper(t)
|
caURL, client := testCAHelper(t)
|
||||||
x5c := testX5CIssuer(t, caURL, "")
|
|
||||||
jwk := testJWKIssuer(t, caURL, "")
|
jwk := testJWKIssuer(t, caURL, "")
|
||||||
|
|
||||||
|
tokenIssuer := testX5CIssuer(t, caURL, "")
|
||||||
|
token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
iss stepIssuer
|
iss stepIssuer
|
||||||
client *ca.Client
|
client *ca.Client
|
||||||
|
@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
||||||
want *apiv1.RenewCertificateResponse
|
want *apiv1.RenewCertificateResponse
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
{"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||||
CSR: testCR,
|
Template: &x509.Certificate{},
|
||||||
|
Backdate: time.Minute,
|
||||||
|
Lifetime: time.Hour,
|
||||||
|
Token: token,
|
||||||
|
}}, &apiv1.RenewCertificateResponse{
|
||||||
|
Certificate: testCrt,
|
||||||
|
CertificateChain: []*x509.Certificate{testIssCrt},
|
||||||
|
}, false},
|
||||||
|
{"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||||
|
Template: &x509.Certificate{},
|
||||||
|
Backdate: time.Minute,
|
||||||
Lifetime: time.Hour,
|
Lifetime: time.Hour,
|
||||||
}}, nil, true},
|
}}, nil, true},
|
||||||
{"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
{"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
|
||||||
CSR: testCR,
|
Template: &x509.Certificate{},
|
||||||
|
Backdate: time.Minute,
|
||||||
Lifetime: time.Hour,
|
Lifetime: time.Hour,
|
||||||
|
Token: "fail",
|
||||||
}}, nil, true},
|
}}, nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want)
|
t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate))
|
||||||
|
t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain))
|
||||||
|
|
||||||
|
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
40
db/db.go
40
db/db.go
|
@ -28,8 +28,9 @@ var (
|
||||||
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||||
)
|
)
|
||||||
|
|
||||||
var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
// TODO: at the moment we store a single CRL in the database, in a dedicated table.
|
||||||
// is this acceptable? probably not....
|
// is this acceptable? probably not....
|
||||||
|
var crlKey = []byte("crl")
|
||||||
|
|
||||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||||
// been previously set.
|
// been previously set.
|
||||||
|
@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
||||||
// CertificateData is the JSON representation of the data stored in
|
// CertificateData is the JSON representation of the data stored in
|
||||||
// x509_certs_data table.
|
// x509_certs_data table.
|
||||||
type CertificateData struct {
|
type CertificateData struct {
|
||||||
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||||
|
RaInfo *provisioner.RAInfo `json:"ra,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvisionerData is the JSON representation of the provisioner stored in the
|
// ProvisionerData is the JSON representation of the provisioner stored in the
|
||||||
|
@ -334,6 +336,10 @@ type ProvisionerData struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type raProvisioner interface {
|
||||||
|
RAInfo() *provisioner.RAInfo
|
||||||
|
}
|
||||||
|
|
||||||
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
||||||
// authorized the certificate.
|
// authorized the certificate.
|
||||||
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
||||||
|
@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
||||||
Name: p.GetName(),
|
Name: p.GetName(),
|
||||||
Type: p.GetType().String(),
|
Type: p.GetType().String(),
|
||||||
}
|
}
|
||||||
|
if rap, ok := p.(raProvisioner); ok {
|
||||||
|
data.RaInfo = rap.RAInfo()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(data)
|
b, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StoreRenewedCertificate stores the leaf certificate and the provisioner that
|
||||||
|
// authorized the old certificate if available.
|
||||||
|
func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error {
|
||||||
|
var certificateData []byte
|
||||||
|
if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil {
|
||||||
|
if b, err := json.Marshal(data); err == nil {
|
||||||
|
certificateData = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf := chain[0]
|
||||||
|
serialNumber := []byte(leaf.SerialNumber.String())
|
||||||
|
|
||||||
|
// Add certificate and certificate data in one transaction.
|
||||||
|
tx := new(database.Tx)
|
||||||
|
tx.Set(certsTable, serialNumber, leaf.Raw)
|
||||||
|
if certificateData != nil {
|
||||||
|
tx.Set(certsDataTable, serialNumber, certificateData)
|
||||||
|
}
|
||||||
|
if err := db.Update(tx); err != nil {
|
||||||
|
return errors.Wrap(err, "database Update error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// UseToken returns true if we were able to successfully store the token for
|
// UseToken returns true if we were able to successfully store the token for
|
||||||
// for the first time, false otherwise.
|
// for the first time, false otherwise.
|
||||||
func (db *DB) UseToken(id, tok string) (bool, error) {
|
func (db *DB) UseToken(id, tok string) (bool, error) {
|
||||||
|
|
142
db/db_test.go
142
db/db_test.go
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrappedProvisioner implements raProvisioner and attProvisioner.
|
||||||
|
type wrappedProvisioner struct {
|
||||||
|
provisioner.Interface
|
||||||
|
raInfo *provisioner.RAInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo {
|
||||||
|
return p.raInfo
|
||||||
|
}
|
||||||
|
|
||||||
func TestDB_StoreCertificateChain(t *testing.T) {
|
func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
p := &provisioner.JWK{
|
p := &provisioner.JWK{
|
||||||
ID: "some-id",
|
ID: "some-id",
|
||||||
Name: "admin",
|
Name: "admin",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
}
|
}
|
||||||
|
rap := &wrappedProvisioner{
|
||||||
|
Interface: p,
|
||||||
|
raInfo: &provisioner.RAInfo{
|
||||||
|
ProvisionerID: "ra-id",
|
||||||
|
ProvisionerType: "JWK",
|
||||||
|
ProvisionerName: "ra",
|
||||||
|
},
|
||||||
|
}
|
||||||
chain := []*x509.Certificate{
|
chain := []*x509.Certificate{
|
||||||
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
|
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
|
||||||
}
|
}
|
||||||
|
@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}, true}, args{p, chain}, false},
|
}, true}, args{p, chain}, false},
|
||||||
|
{"ok ra provisioner", fields{&MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
if len(tx.Operations) != 2 {
|
||||||
|
t.Fatal("unexpected number of operations")
|
||||||
|
}
|
||||||
|
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
|
||||||
|
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
|
||||||
|
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
|
||||||
|
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
|
||||||
|
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
|
||||||
|
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value)
|
||||||
|
assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, true}, args{rap, chain}, false},
|
||||||
{"ok no provisioner", fields{&MockNoSQLDB{
|
{"ok no provisioner", fields{&MockNoSQLDB{
|
||||||
MUpdate: func(tx *database.Tx) error {
|
MUpdate: func(tx *database.Tx) error {
|
||||||
if len(tx.Operations) != 2 {
|
if len(tx.Operations) != 2 {
|
||||||
|
@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDB_StoreRenewedCertificate(t *testing.T) {
|
||||||
|
oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||||
|
chain := []*x509.Certificate{
|
||||||
|
&x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")},
|
||||||
|
&x509.Certificate{SerialNumber: big.NewInt(0)},
|
||||||
|
}
|
||||||
|
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`)
|
||||||
|
matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool {
|
||||||
|
return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
DB nosql.DB
|
||||||
|
isUp bool
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
oldCert *x509.Certificate
|
||||||
|
chain []*x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) {
|
||||||
|
return certsData, nil
|
||||||
|
}
|
||||||
|
t.Error("ok failed: unexpected get")
|
||||||
|
return nil, testErr
|
||||||
|
},
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
if len(tx.Operations) != 2 {
|
||||||
|
t.Error("ok failed: unexpected number of operations")
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
op0, op1 := tx.Operations[0], tx.Operations[1]
|
||||||
|
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||||
|
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
if !matchOperation(op1, certsDataTable, []byte("2"), certsData) {
|
||||||
|
t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value)
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, true}, args{oldCert, chain}, false},
|
||||||
|
{"ok no data", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
if len(tx.Operations) != 1 {
|
||||||
|
t.Error("ok failed: unexpected number of operations")
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
op0 := tx.Operations[0]
|
||||||
|
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||||
|
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, true}, args{oldCert, chain}, false},
|
||||||
|
{"ok fail marshal", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return []byte(`{"bad":"json"`), nil
|
||||||
|
},
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
if len(tx.Operations) != 1 {
|
||||||
|
t.Error("ok failed: unexpected number of operations")
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
op0 := tx.Operations[0]
|
||||||
|
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
|
||||||
|
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
|
||||||
|
return testErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, true}, args{oldCert, chain}, false},
|
||||||
|
{"fail", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return certsData, nil
|
||||||
|
},
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
return testErr
|
||||||
|
},
|
||||||
|
}, true}, args{oldCert, chain}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db := &DB{
|
||||||
|
DB: tt.fields.DB,
|
||||||
|
isUp: tt.fields.isUp,
|
||||||
|
}
|
||||||
|
if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue