From c7f226bcec732dda892d8755e1078559f7a0a69d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 4 Nov 2022 16:42:07 -0700 Subject: [PATCH] Add support for renew when using stepcas It supports renewing X.509 certificates when an RA is configured with stepcas. This will only work when the renewal uses a token, and it won't work with mTLS. The audience cannot be properly verified when an RA is used, to avoid this we will get from the database if an RA was used to issue the initial certificate and we will accept the renew token. Fixes #1021 for stepcas --- api/api.go | 1 + api/api_test.go | 8 ++ api/renew.go | 24 ++++-- authority/authorize.go | 6 +- authority/authorize_test.go | 34 +++++++- authority/provisioners.go | 19 +++++ authority/provisioners_test.go | 51 ++++++++++++ authority/tls.go | 60 +++++++++++--- authority/tls_test.go | 10 +-- cas/apiv1/requests.go | 1 + cas/apiv1/services.go | 20 +++++ cas/apiv1/services_test.go | 48 +++++++++++ cas/stepcas/stepcas.go | 20 ++++- cas/stepcas/stepcas_test.go | 42 ++++++++-- db/db.go | 40 +++++++++- db/db_test.go | 142 +++++++++++++++++++++++++++++++++ 16 files changed, 487 insertions(+), 39 deletions(-) diff --git a/api/api.go b/api/api.go index fda27c42..9c2f1f31 100644 --- a/api/api.go +++ b/api/api.go @@ -40,6 +40,7 @@ type Authority interface { Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) + RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error) diff --git a/api/api_test.go b/api/api_test.go index abbbbd5b..e24751b3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -192,6 +192,7 @@ type mockAuthority struct { sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByName func(name string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) @@ -264,6 +265,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.renewContext != nil { + return m.renewContext(ctx, oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { if m.rekey != nil { return m.rekey(oldcert, pk) diff --git a/api/renew.go b/api/renew.go index 6e9f680f..1b9ed95f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" ) @@ -17,14 +18,22 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. func Renew(w http.ResponseWriter, r *http.Request) { - cert, err := getPeerCertificate(r) + ctx := r.Context() + + // Get the leaf certificate from the peer or the token. + cert, token, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - a := mustAuthority(r.Context()) - certChain, err := a.Renew(cert) + // The token can be used by RAs to renew a certificate. + if token != "" { + ctx = authority.NewTokenContext(ctx, token) + } + + a := mustAuthority(ctx) + certChain, err := a.RenewContext(ctx, cert, nil) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -44,15 +53,16 @@ func Renew(w http.ResponseWriter, r *http.Request) { }, http.StatusCreated) } -func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { - return r.TLS.PeerCertificates[0], nil + return r.TLS.PeerCertificates[0], "", nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { ctx := r.Context() - return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) + peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) + return peer, parts[1], err } } - return nil, errs.BadRequest("missing client certificate") + return nil, "", errs.BadRequest("missing client certificate") } diff --git a/authority/authorize.go b/authority/authorize.go index 44956cbd..1e50da89 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // extra extension cannot be found, authorize the renewal by default. // // TODO(mariano): should we authorize by default? -func (a *Authority) authorizeRenew(cert *x509.Certificate) error { +func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} @@ -308,7 +308,7 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) } } - if err := p.AuthorizeRenew(context.Background(), cert); err != nil { + if err := p.AuthorizeRenew(ctx, cert); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } return nil @@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509. } audiences := a.config.GetAudiences().Renew - if !matchesAudience(claims.Audience, audiences) { + if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) { return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 7dc22f3a..bec34fd6 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - err := tc.auth.authorizeRenew(tc.cert) + err := tc.auth.authorizeRenew(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError @@ -1459,6 +1459,37 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { }) return nil })) + a4 := testAuthority(t) + a4.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"}, + RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }, nil + }, + } + t4, c4 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://ra.example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-ca-client/1.0", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", @@ -1627,6 +1658,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { {"ok", a1, args{ctx, t1}, c1, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false}, {"ok provisioner issuer", a1, args{ctx, t3}, c3, false}, + {"ok ra provisioner", a4, args{ctx, t4}, c4, false}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true}, diff --git a/authority/provisioners.go b/authority/provisioners.go index bfa4eae5..d8a7b4d1 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -48,6 +48,22 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa } } +// wrapRAProvisioner wraps the given provisioner with RA information. +func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner { + return &wrappedProvisioner{ + Interface: p, + raInfo: raInfo, + } +} + +// isRAProvisioner returns if the given provisioner is an RA provisioner. +func isRAProvisioner(p provisioner.Interface) bool { + if rap, ok := p.(raProvisioner); ok { + return rap.RAInfo() != nil + } + return false +} + // wrappedProvisioner implements raProvisioner and attProvisioner. type wrappedProvisioner struct { provisioner.Interface @@ -119,6 +135,9 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr } if err == nil && data != nil && data.Provisioner != nil { if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { + if data.RaInfo != nil { + return wrapRAProvisioner(p, data.RaInfo), nil + } return p, nil } } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 6ef62223..7901de6a 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -333,3 +333,54 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) { }) } } + +func Test_wrapRAProvisioner(t *testing.T) { + type args struct { + p provisioner.Interface + raInfo *provisioner.RAInfo + } + tests := []struct { + name string + args args + want *wrappedProvisioner + }{ + {"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_isRAProvisioner(t *testing.T) { + type args struct { + p provisioner.Interface + } + tests := []struct { + name string + args args + want bool + }{ + {"true", args{&wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, + }}, true}, + {"nil ra", args{&wrappedProvisioner{ + Interface: &provisioner.JWK{Name: "jwt"}, + }}, false}, + {"not ra", args{&provisioner.JWK{Name: "jwt"}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isRAProvisioner(tt.args.p); got != tt.want { + t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/tls.go b/authority/tls.go index b5d85074..11c61b9e 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -34,6 +34,19 @@ import ( "github.com/smallstep/nosql/database" ) +type tokenKey struct{} + +// NewTokenContext adds the given token to the context. +func NewTokenContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenKey{}, token) +} + +// TokenFromContext returns the token from the given context. +func TokenFromContext(ctx context.Context) (token string, ok bool) { + token, ok = ctx.Value(tokenKey{}).(string) + return +} + // GetTLSOptions returns the tls options configured. func (a *Authority) GetTLSOptions() *config.TLSOptions { return a.config.TLS @@ -294,28 +307,44 @@ func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error { return a.policyEngine.AreSANsAllowed(sans) } -// Renew creates a new Certificate identical to the old certificate, except -// with a validity window that begins 'now'. +// Renew creates a new Certificate identical to the old certificate, except with +// a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { - return a.Rekey(oldCert, nil) + return a.RenewContext(context.Background(), oldCert, nil) } -// Rekey is used for rekeying and renewing based on the public key. -// If the public key is 'nil' then it's assumed that the cert should be renewed -// using the existing public key. If the public key is not 'nil' then it's -// assumed that the cert should be rekeyed. +// Rekey is used for rekeying and renewing based on the public key. If the +// public key is 'nil' then it's assumed that the cert should be renewed using +// the existing public key. If the public key is not 'nil' then it's assumed +// that the cert should be rekeyed. +// // For both Rekey and Renew all other attributes of the new certificate should // match the old certificate. The exceptions are 'AuthorityKeyId' (which may // have changed), 'SubjectKeyId' (different in case of rekey), and // 'NotBefore/NotAfter' (the validity duration of the new certificate should be // equal to the old one, but starting 'now'). func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + return a.RenewContext(context.Background(), oldCert, pk) +} + +// RenewContext creates a new certificate identical to the old one, but it can +// optionally replace the public key with the given one. When running on RA +// mode, it can only renew a certificate using a renew token instead. +// +// For both rekey and renew operations, all other attributes of the new +// certificate should match the old certificate. The exceptions are +// 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case +// of rekey), and 'NotBefore/NotAfter' (the validity duration of the new +// certificate should be equal to the old one, but starting 'now'). +func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { isRekey := (pk != nil) - opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} + opts := []errs.Option{ + errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), + } // Check step provisioner extensions - if err := a.authorizeRenew(oldCert); err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) + if err := a.authorizeRenew(ctx, oldCert); err != nil { + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } // Durations @@ -388,7 +417,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 if err := a.constraintsEngine.ValidateCertificate(newCert); err != nil { var ee *errs.Error if errors.As(err, &ee) { - return nil, errs.ApplyOptions(ee, opts...) + return nil, errs.StatusCodeError(ee.StatusCode(), err, opts...) } return nil, errs.InternalServerErr(err, errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), @@ -396,19 +425,24 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 ) } + // The token can optionally be in the context. If the CA is running in RA + // mode, this can be used to renew a certificate. + token, _ := TokenFromContext(ctx) + resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{ Template: newCert, Lifetime: lifetime, Backdate: backdate, + Token: token, }) if err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil { if !errors.Is(err, db.ErrNotImplemented) { - return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) + return nil, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } } diff --git a/authority/tls_test.go b/authority/tls_test.go index 918adbdc..5d63b3dd 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -992,14 +992,14 @@ func TestAuthority_Renew(t *testing.T) { return &renewTest{ auth: _a, cert: cert, - err: errors.New("authority.Rekey: error creating certificate"), + err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1012,7 +1012,7 @@ func TestAuthority_Renew(t *testing.T) { return &renewTest{ auth: aa, cert: cert, - err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), + err: errors.New("authority.authorizeRenew: not authorized"), code: http.StatusUnauthorized, }, nil }, @@ -1221,14 +1221,14 @@ func TestAuthority_Rekey(t *testing.T) { return &renewTest{ auth: _a, cert: cert, - err: errors.New("authority.Rekey: error creating certificate"), + err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, diff --git a/cas/apiv1/requests.go b/cas/apiv1/requests.go index eff53a77..fdbb285e 100644 --- a/cas/apiv1/requests.go +++ b/cas/apiv1/requests.go @@ -81,6 +81,7 @@ type RenewCertificateRequest struct { CSR *x509.CertificateRequest Lifetime time.Duration Backdate time.Duration + Token string RequestID string } diff --git a/cas/apiv1/services.go b/cas/apiv1/services.go index f1d02b3c..f10a3e17 100644 --- a/cas/apiv1/services.go +++ b/cas/apiv1/services.go @@ -83,3 +83,23 @@ func (e NotImplementedError) Error() string { func (e NotImplementedError) StatusCode() int { return http.StatusNotImplemented } + +// ValidationError is the type of error returned if request is not properly +// validated. +type ValidationError struct { + Message string +} + +// NotImplementedError implements the error interface. +func (e ValidationError) Error() string { + if e.Message != "" { + return e.Message + } + return "bad request" +} + +// StatusCode implements the StatusCoder interface and returns the HTTP 400 +// error. +func (e ValidationError) StatusCode() int { + return http.StatusBadRequest +} diff --git a/cas/apiv1/services_test.go b/cas/apiv1/services_test.go index f8e16138..9289de76 100644 --- a/cas/apiv1/services_test.go +++ b/cas/apiv1/services_test.go @@ -71,3 +71,51 @@ func TestNotImplementedError_StatusCode(t *testing.T) { }) } } + +func TestValidationError_Error(t *testing.T) { + type fields struct { + Message string + } + tests := []struct { + name string + fields fields + want string + }{ + {"default", fields{""}, "bad request"}, + {"with message", fields{"token is empty"}, "token is empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ValidationError{ + Message: tt.fields.Message, + } + if got := e.Error(); got != tt.want { + t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidationError_StatusCode(t *testing.T) { + type fields struct { + Message string + } + tests := []struct { + name string + fields fields + want int + }{ + {"default", fields{""}, 400}, + {"with message", fields{"token is empty"}, 400}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ValidationError{ + Message: tt.fields.Message, + } + if got := e.StatusCode(); got != tt.want { + t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index 6c2acc84..c64963e6 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -101,7 +101,25 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 // RenewCertificate will always return a non-implemented error as mTLS renewals // are not supported yet. func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { - return nil, apiv1.NotImplementedError{Message: "stepCAS does not support mTLS renewals"} + if req.Token == "" { + return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"} + } + + resp, err := s.client.RenewWithToken(req.Token) + if err != nil { + return nil, err + } + + var chain []*x509.Certificate + cert := resp.CertChainPEM[0].Certificate + for _, c := range resp.CertChainPEM[1:] { + chain = append(chain, c.Certificate) + } + + return &apiv1.RenewCertificateResponse{ + Certificate: cert, + CertificateChain: chain, + }, nil } // RevokeCertificate revokes a certificate. diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index cc8ea72e..6691a4b4 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -147,6 +147,16 @@ func testCAHelper(t *testing.T) (*url.URL, *ca.Client) { writeJSON(w, api.SignResponse{ CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, }) + case r.RequestURI == "/renew": + if r.Header.Get("Authorization") == "Bearer fail" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `{"error":"fail","message":"fail"}`) + return + } + w.WriteHeader(http.StatusOK) + writeJSON(w, api.SignResponse{ + CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, + }) case r.RequestURI == "/revoke": var msg api.RevokeRequest parseJSON(r, &msg) @@ -723,9 +733,14 @@ func TestStepCAS_CreateCertificate(t *testing.T) { func TestStepCAS_RenewCertificate(t *testing.T) { caURL, client := testCAHelper(t) - x5c := testX5CIssuer(t, caURL, "") jwk := testJWKIssuer(t, caURL, "") + tokenIssuer := testX5CIssuer(t, caURL, "") + token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil) + if err != nil { + t.Fatal(err) + } + type fields struct { iss stepIssuer client *ca.Client @@ -741,13 +756,25 @@ func TestStepCAS_RenewCertificate(t *testing.T) { want *apiv1.RenewCertificateResponse wantErr bool }{ - {"not implemented", fields{x5c, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ - CSR: testCR, + {"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, + Lifetime: time.Hour, + Token: token, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testCrt, + CertificateChain: []*x509.Certificate{testIssCrt}, + }, false}, + {"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, Lifetime: time.Hour, }}, nil, true}, - {"not implemented jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ - CSR: testCR, + {"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ + Template: &x509.Certificate{}, + Backdate: time.Minute, Lifetime: time.Hour, + Token: "fail", }}, nil, true}, } for _, tt := range tests { @@ -763,7 +790,10 @@ func TestStepCAS_RenewCertificate(t *testing.T) { return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got, tt.want) + t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate)) + t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain)) + + t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject) } }) } diff --git a/db/db.go b/db/db.go index 784c75f4..b3137a50 100644 --- a/db/db.go +++ b/db/db.go @@ -28,8 +28,9 @@ var ( sshHostPrincipalsTable = []byte("ssh_host_principals") ) -var crlKey = []byte("crl") //TODO: at the moment we store a single CRL in the database, in a dedicated table. -// is this acceptable? probably not.... +// TODO: at the moment we store a single CRL in the database, in a dedicated table. +// is this acceptable? probably not.... +var crlKey = []byte("crl") // ErrAlreadyExists can be returned if the DB attempts to set a key that has // been previously set. @@ -323,7 +324,8 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { // CertificateData is the JSON representation of the data stored in // x509_certs_data table. type CertificateData struct { - Provisioner *ProvisionerData `json:"provisioner,omitempty"` + Provisioner *ProvisionerData `json:"provisioner,omitempty"` + RaInfo *provisioner.RAInfo `json:"ra,omitempty"` } // ProvisionerData is the JSON representation of the provisioner stored in the @@ -334,6 +336,10 @@ type ProvisionerData struct { Type string `json:"type"` } +type raProvisioner interface { + RAInfo() *provisioner.RAInfo +} + // StoreCertificateChain stores the leaf certificate and the provisioner that // authorized the certificate. func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { @@ -346,6 +352,9 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert Name: p.GetName(), Type: p.GetType().String(), } + if rap, ok := p.(raProvisioner); ok { + data.RaInfo = rap.RAInfo() + } } b, err := json.Marshal(data) if err != nil { @@ -361,6 +370,31 @@ func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Cert return nil } +// StoreRenewedCertificate stores the leaf certificate and the provisioner that +// authorized the old certificate if available. +func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error { + var certificateData []byte + if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil { + if b, err := json.Marshal(data); err == nil { + certificateData = b + } + } + + leaf := chain[0] + serialNumber := []byte(leaf.SerialNumber.String()) + + // Add certificate and certificate data in one transaction. + tx := new(database.Tx) + tx.Set(certsTable, serialNumber, leaf.Raw) + if certificateData != nil { + tx.Set(certsDataTable, serialNumber, certificateData) + } + if err := db.Update(tx); err != nil { + return errors.Wrap(err, "database Update error") + } + return nil +} + // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { diff --git a/db/db_test.go b/db/db_test.go index b4515a5b..7668ae58 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "bytes" "crypto/x509" "errors" "math/big" @@ -164,12 +165,30 @@ func TestUseToken(t *testing.T) { } } +// wrappedProvisioner implements raProvisioner and attProvisioner. +type wrappedProvisioner struct { + provisioner.Interface + raInfo *provisioner.RAInfo +} + +func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo { + return p.raInfo +} + func TestDB_StoreCertificateChain(t *testing.T) { p := &provisioner.JWK{ ID: "some-id", Name: "admin", Type: "JWK", } + rap := &wrappedProvisioner{ + Interface: p, + raInfo: &provisioner.RAInfo{ + ProvisionerID: "ra-id", + ProvisionerType: "JWK", + ProvisionerName: "ra", + }, + } chain := []*x509.Certificate{ {Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)}, } @@ -201,6 +220,21 @@ func TestDB_StoreCertificateChain(t *testing.T) { return nil }, }, true}, args{p, chain}, false}, + {"ok ra provisioner", fields{&MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Fatal("unexpected number of operations") + } + assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[0].Key) + assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) + assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) + assert.Equals(t, []byte("1234"), tx.Operations[1].Key) + assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value) + assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value)) + return nil + }, + }, true}, args{rap, chain}, false}, {"ok no provisioner", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { @@ -293,3 +327,111 @@ func TestDB_GetCertificateData(t *testing.T) { }) } } + +func TestDB_StoreRenewedCertificate(t *testing.T) { + oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)} + chain := []*x509.Certificate{ + &x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")}, + &x509.Certificate{SerialNumber: big.NewInt(0)}, + } + + testErr := errors.New("test error") + certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`) + matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool { + return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value) + } + + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + oldCert *x509.Certificate + chain []*x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) { + return certsData, nil + } + t.Error("ok failed: unexpected get") + return nil, testErr + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 2 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0, op1 := tx.Operations[0], tx.Operations[1] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + if !matchOperation(op1, certsDataTable, []byte("2"), certsData) { + t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"ok no data", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 1 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0 := tx.Operations[0] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"ok fail marshal", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return []byte(`{"bad":"json"`), nil + }, + MUpdate: func(tx *database.Tx) error { + if len(tx.Operations) != 1 { + t.Error("ok failed: unexpected number of operations") + return testErr + } + op0 := tx.Operations[0] + if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { + t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) + return testErr + } + return nil + }, + }, true}, args{oldCert, chain}, false}, + {"fail", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return certsData, nil + }, + MUpdate: func(tx *database.Tx) error { + return testErr + }, + }, true}, args{oldCert, chain}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr { + t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}