diff --git a/authority/authorize.go b/authority/authorize.go index 91a07353..8555db9b 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -271,10 +271,19 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenew(cert *x509.Certificate) error { + var err error + var isRevoked bool var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} // Check the passive revocation table. - isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) + serial := cert.SerialNumber.String() + if lca, ok := a.adminDB.(interface { + IsRevoked(string) (bool, error) + }); ok { + isRevoked, err = lca.IsRevoked(serial) + } else { + isRevoked, err = a.db.IsRevoked(serial) + } if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } @@ -294,8 +303,17 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { // authorizeSSHCertificate returns an error if the given certificate is revoked. func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error { + var err error + var isRevoked bool + serial := strconv.FormatUint(cert.Serial, 10) - isRevoked, err := a.db.IsSSHRevoked(serial) + if lca, ok := a.adminDB.(interface { + IsSSHRevoked(string) (bool, error) + }); ok { + isRevoked, err = lca.IsSSHRevoked(serial) + } else { + isRevoked, err = a.db.IsSSHRevoked(serial) + } if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial)) } diff --git a/authority/linkedca.go b/authority/linkedca.go index 117f19ef..79427c5c 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -269,6 +269,30 @@ func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error { return errors.Wrap(err, "error posting ssh certificate") } +func (c *linkedCaClient) IsRevoked(serial string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{ + Serial: serial, + }) + if err != nil { + return false, errors.Wrap(err, "error getting certificate status") + } + return resp.Status != linkedca.RevocationStatus_ACTIVE, nil +} + +func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{ + Serial: serial, + }) + if err != nil { + return false, errors.Wrap(err, "error getting certificate status") + } + return resp.Status != linkedca.RevocationStatus_ACTIVE, nil +} + func serializeCertificateChain(fullchain ...*x509.Certificate) string { var chain string for _, crt := range fullchain { diff --git a/go.mod b/go.mod index 6957cc83..98e7dbdb 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 go.step.sm/cli-utils v0.4.1 go.step.sm/crypto v0.9.0 - go.step.sm/linkedca v0.1.0 + go.step.sm/linkedca v0.3.0 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 golang.org/x/net v0.0.0-20210716203947-853a461950ff golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect