Merge pull request #1156 from smallstep/ra-renew

Add support for renew when using stepcas
This commit is contained in:
Mariano Cano 2022-11-07 15:36:01 -08:00 committed by GitHub
commit e8726d24fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 487 additions and 39 deletions

View file

@ -40,6 +40,7 @@ type Authority interface {
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
LoadProvisionerByName(string) (provisioner.Interface, error)

View file

@ -192,6 +192,7 @@ type mockAuthority struct {
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
loadProvisionerByName func(name string) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.renewContext != nil {
return m.renewContext(ctx, oldcert, pk)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.rekey != nil {
return m.rekey(oldcert, pk)

View file

@ -6,6 +6,7 @@ import (
"strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
)
@ -17,14 +18,22 @@ const (
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func Renew(w http.ResponseWriter, r *http.Request) {
cert, err := getPeerCertificate(r)
ctx := r.Context()
// Get the leaf certificate from the peer or the token.
cert, token, err := getPeerCertificate(r)
if err != nil {
render.Error(w, err)
return
}
a := mustAuthority(r.Context())
certChain, err := a.Renew(cert)
// The token can be used by RAs to renew a certificate.
if token != "" {
ctx = authority.NewTokenContext(ctx, token)
}
a := mustAuthority(ctx)
certChain, err := a.RenewContext(ctx, cert, nil)
if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) {
}, http.StatusCreated)
}
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
return r.TLS.PeerCertificates[0], "", nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
ctx := r.Context()
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
return peer, parts[1], err
}
}
return nil, errs.BadRequest("missing client certificate")
return nil, "", errs.BadRequest("missing client certificate")
}

View file

@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// extra extension cannot be found, authorize the renewal by default.
//
// TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error {
serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
if err := p.AuthorizeRenew(ctx, cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
return nil
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) {
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}

View file

@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
err := tc.auth.authorizeRenew(tc.cert)
err := tc.auth.authorizeRenew(context.Background(), tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError
@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
})
return nil
}))
a4 := testAuthority(t)
a4.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}, nil
},
}
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://ra.example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},

View file

@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
}
}
// wrapRAProvisioner wraps the given provisioner with RA information.
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
return &wrappedProvisioner{
Interface: p,
raInfo: raInfo,
}
}
// isRAProvisioner returns if the given provisioner is an RA provisioner.
func isRAProvisioner(p provisioner.Interface) bool {
if rap, ok := p.(raProvisioner); ok {
return rap.RAInfo() != nil
}
return false
}
// wrappedProvisioner implements raProvisioner and attProvisioner.
type wrappedProvisioner struct {
provisioner.Interface
@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
}
if err == nil && data != nil && data.Provisioner != nil {
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
if data.RaInfo != nil {
return wrapRAProvisioner(p, data.RaInfo), nil
}
return p, nil
}
}

View file

@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
})
}
}
func Test_wrapRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
raInfo *provisioner.RAInfo
}
tests := []struct {
name string
args args
want *wrappedProvisioner
}{
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
}
tests := []struct {
name string
args args
want bool
}{
{"true", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}}, true},
{"nil ra", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
}}, false},
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isRAProvisioner(tt.args.p); got != tt.want {
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -34,6 +34,19 @@ import (
"github.com/smallstep/nosql/database"
)
type tokenKey struct{}
// NewTokenContext adds the given token to the context.
func NewTokenContext(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, tokenKey{}, token)
}
// TokenFromContext returns the token from the given context.
func TokenFromContext(ctx context.Context) (token string, ok bool) {
token, ok = ctx.Value(tokenKey{}).(string)
return
}
// GetTLSOptions returns the tls options configured.
func (a *Authority) GetTLSOptions() *config.TLSOptions {
return a.config.TLS
@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
return a.policyEngine.AreSANsAllowed(sans)
}
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
// Renew creates a new Certificate identical to the old certificate, except with
// a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
return a.Rekey(oldCert, nil)
return a.RenewContext(context.Background(), oldCert, nil)
}
// Rekey is used for rekeying and renewing based on the public key.
// If the public key is 'nil' then it's assumed that the cert should be renewed
// using the existing public key. If the public key is not 'nil' then it's
// assumed that the cert should be rekeyed.
// Rekey is used for rekeying and renewing based on the public key. If the
// public key is 'nil' then it's assumed that the cert should be renewed using
// the existing public key. If the public key is not 'nil' then it's assumed
// that the cert should be rekeyed.
//
// For both Rekey and Renew all other attributes of the new certificate should
// match the old certificate. The exceptions are 'AuthorityKeyId' (which may
// have changed), 'SubjectKeyId' (different in case of rekey), and
// 'NotBefore/NotAfter' (the validity duration of the new certificate should be
// equal to the old one, but starting 'now').
func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
return a.RenewContext(context.Background(), oldCert, pk)
}
// RenewContext creates a new certificate identical to the old one, but it can
// optionally replace the public key with the given one. When running on RA
// mode, it can only renew a certificate using a renew token instead.
//
// For both rekey and renew operations, all other attributes of the new
// certificate should match the old certificate. The exceptions are
// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case
// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new
// certificate should be equal to the old one, but starting 'now').
func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
isRekey := (pk != nil)
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
opts := []errs.Option{
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
}
// Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
if err := a.authorizeRenew(ctx, oldCert); err != nil {
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
// Durations
@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil {
var ee *errs.Error
if errors.As(err, &ee) {
return nil, errs.ApplyOptions(ee, opts...)
return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...)
}
return nil, errs.InternalServerErr(err,
errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()),
@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
)
}
// The token can optionally be in the context. If the CA is running in RA
// mode, this can be used to renew a certificate.
token, _ := TokenFromContext(ctx)
resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{
Template: newCert,
Lifetime: lifetime,
Backdate: backdate,
Token: token,
})
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
if !errors.Is(err, db.ErrNotImplemented) {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...)
}
}

View file

@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) {
return &renewTest{
auth: _a,
cert: cert,
err: errors.New("authority.Rekey: error creating certificate"),
err: errors.New("error creating certificate"),
code: http.StatusInternalServerError,
}, nil
},
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) {
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
err: errors.New("authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized,
}, nil
},
@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) {
return &renewTest{
auth: _a,
cert: cert,
err: errors.New("authority.Rekey: error creating certificate"),
err: errors.New("error creating certificate"),
code: http.StatusInternalServerError,
}, nil
},
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},

View file

@ -81,6 +81,7 @@ type RenewCertificateRequest struct {
CSR *x509.CertificateRequest
Lifetime time.Duration
Backdate time.Duration
Token string
RequestID string
}

View file

@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string {
func (e NotImplementedError) StatusCode() int {
return http.StatusNotImplemented
}
// ValidationError is the type of error returned if request is not properly
// validated.
type ValidationError struct {
Message string
}
// NotImplementedError implements the error interface.
func (e ValidationError) Error() string {
if e.Message != "" {
return e.Message
}
return "bad request"
}
// StatusCode implements the StatusCoder interface and returns the HTTP 400
// error.
func (e ValidationError) StatusCode() int {
return http.StatusBadRequest
}

View file

@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) {
})
}
}
func TestValidationError_Error(t *testing.T) {
type fields struct {
Message string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{""}, "bad request"},
{"with message", fields{"token is empty"}, "token is empty"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ValidationError{
Message: tt.fields.Message,
}
if got := e.Error(); got != tt.want {
t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestValidationError_StatusCode(t *testing.T) {
type fields struct {
Message string
}
tests := []struct {
name string
fields fields
want int
}{
{"default", fields{""}, 400},
{"with message", fields{"token is empty"}, 400},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := ValidationError{
Message: tt.fields.Message,
}
if got := e.StatusCode(); got != tt.want {
t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
// RenewCertificate will always return a non-implemented error as mTLS renewals
// are not supported yet.
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) {
return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"}
if req.Token == "" {
return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"}
}
resp, err := s.client.RenewWithToken(req.Token)
if err != nil {
return nil, err
}
var chain []*x509.Certificate
cert := resp.CertChainPEM[0].Certificate
for _, c := range resp.CertChainPEM[1:] {
chain = append(chain, c.Certificate)
}
return &apiv1.RenewCertificateResponse{
Certificate: cert,
CertificateChain: chain,
}, nil
}
// RevokeCertificate revokes a certificate.

View file

@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) {
writeJSON(w, api.SignResponse{
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
})
case r.RequestURI == "/renew":
if r.Header.Get("Authorization") == "Bearer fail" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `{"error":"fail","message":"fail"}`)
return
}
w.WriteHeader(http.StatusOK)
writeJSON(w, api.SignResponse{
CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)},
})
case r.RequestURI == "/revoke":
var msg api.RevokeRequest
parseJSON(r, &msg)
@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) {
func TestStepCAS_RenewCertificate(t *testing.T) {
caURL, client := testCAHelper(t)
x5c := testX5CIssuer(t, caURL, "")
jwk := testJWKIssuer(t, caURL, "")
tokenIssuer := testX5CIssuer(t, caURL, "")
token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil)
if err != nil {
t.Fatal(err)
}
type fields struct {
iss stepIssuer
client *ca.Client
@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
want *apiv1.RenewCertificateResponse
wantErr bool
}{
{"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
CSR: testCR,
{"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
Token: token,
}}, &apiv1.RenewCertificateResponse{
Certificate: testCrt,
CertificateChain: []*x509.Certificate{testIssCrt},
}, false},
{"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
}}, nil, true},
{"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
CSR: testCR,
{"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{
Template: &x509.Certificate{},
Backdate: time.Minute,
Lifetime: time.Hour,
Token: "fail",
}}, nil, true},
}
for _, tt := range tests {
@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) {
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want)
t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate))
t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain))
t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject)
}
})
}

View file

@ -28,8 +28,9 @@ var (
sshHostPrincipalsTable = []byte("ssh_host_principals")
)
var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table.
// TODO: at the moment we store a single CRL in the database, in a dedicated table.
// is this acceptable? probably not....
var crlKey = []byte("crl")
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
// been previously set.
@ -324,6 +325,7 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
// x509_certs_data table.
type CertificateData struct {
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
RaInfo *provisioner.RAInfo `json:"ra,omitempty"`
}
// ProvisionerData is the JSON representation of the provisioner stored in the
@ -334,6 +336,10 @@ type ProvisionerData struct {
Type string `json:"type"`
}
type raProvisioner interface {
RAInfo() *provisioner.RAInfo
}
// StoreCertificateChain stores the leaf certificate and the provisioner that
// authorized the certificate.
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
Name: p.GetName(),
Type: p.GetType().String(),
}
if rap, ok := p.(raProvisioner); ok {
data.RaInfo = rap.RAInfo()
}
}
b, err := json.Marshal(data)
if err != nil {
@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert
return nil
}
// StoreRenewedCertificate stores the leaf certificate and the provisioner that
// authorized the old certificate if available.
func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error {
var certificateData []byte
if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil {
if b, err := json.Marshal(data); err == nil {
certificateData = b
}
}
leaf := chain[0]
serialNumber := []byte(leaf.SerialNumber.String())
// Add certificate and certificate data in one transaction.
tx := new(database.Tx)
tx.Set(certsTable, serialNumber, leaf.Raw)
if certificateData != nil {
tx.Set(certsDataTable, serialNumber, certificateData)
}
if err := db.Update(tx); err != nil {
return errors.Wrap(err, "database Update error")
}
return nil
}
// UseToken returns true if we were able to successfully store the token for
// for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) {

View file

@ -1,6 +1,7 @@
package db
import (
"bytes"
"crypto/x509"
"errors"
"math/big"
@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) {
}
}
// wrappedProvisioner implements raProvisioner and attProvisioner.
type wrappedProvisioner struct {
provisioner.Interface
raInfo *provisioner.RAInfo
}
func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo {
return p.raInfo
}
func TestDB_StoreCertificateChain(t *testing.T) {
p := &provisioner.JWK{
ID: "some-id",
Name: "admin",
Type: "JWK",
}
rap := &wrappedProvisioner{
Interface: p,
raInfo: &provisioner.RAInfo{
ProvisionerID: "ra-id",
ProvisionerType: "JWK",
ProvisionerName: "ra",
},
}
chain := []*x509.Certificate{
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
}
@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) {
return nil
},
}, true}, args{p, chain}, false},
{"ok ra provisioner", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Fatal("unexpected number of operations")
}
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value)
assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value))
return nil
},
}, true}, args{rap, chain}, false},
{"ok no provisioner", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) {
})
}
}
func TestDB_StoreRenewedCertificate(t *testing.T) {
oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)}
chain := []*x509.Certificate{
&x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")},
&x509.Certificate{SerialNumber: big.NewInt(0)},
}
testErr := errors.New("test error")
certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`)
matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool {
return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value)
}
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
oldCert *x509.Certificate
chain []*x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) {
return certsData, nil
}
t.Error("ok failed: unexpected get")
return nil, testErr
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0, op1 := tx.Operations[0], tx.Operations[1]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
if !matchOperation(op1, certsDataTable, []byte("2"), certsData) {
t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"ok no data", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 1 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0 := tx.Operations[0]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"ok fail marshal", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return []byte(`{"bad":"json"`), nil
},
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 1 {
t.Error("ok failed: unexpected number of operations")
return testErr
}
op0 := tx.Operations[0]
if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) {
t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value)
return testErr
}
return nil
},
}, true}, args{oldCert, chain}, false},
{"fail", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return certsData, nil
},
MUpdate: func(tx *database.Tx) error {
return testErr
},
}, true}, args{oldCert, chain}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr {
t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}