certificates/acme/api/revoke_test.go
2021-11-26 17:27:42 +01:00

670 lines
19 KiB
Go

package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
)
// certPEM is a PEM-encoded certificate used as a test fixture.
// TestHandler_RevokeCert decodes it with parseCertificate and submits its raw
// DER bytes as the certificate to be revoked.
const (
certPEM = `-----BEGIN CERTIFICATE-----
MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE
BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl
cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw
WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN
TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp
bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q
5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC
7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa
BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF
BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy
LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz
cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf
BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG
AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t
L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+
gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283
TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq
0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW
RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh
yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
)
// v returns a pointer to a fresh copy of the given int. The table tests use
// it to fill optional *int fields from literals and constants, which cannot
// be addressed directly.
func v(v int) *int {
	ptr := new(int)
	*ptr = v
	return ptr
}
// parseCertificate decodes the first PEM block found in data and parses its
// bytes as an X.509 certificate.
//
// It is a fixture helper: instead of returning an error it panics, either
// when data contains no PEM block or when the block's bytes are not a
// parseable certificate.
func parseCertificate(data string) *x509.Certificate {
	// Any input following the first PEM block is deliberately ignored.
	pemBlock, _ := pem.Decode([]byte(data))
	if pemBlock == nil {
		panic("failed to parse certificate PEM")
	}

	parsed, parseErr := x509.ParseCertificate(pemBlock.Bytes)
	if parseErr == nil {
		return parsed
	}
	panic("failed to parse certificate: " + parseErr.Error())
}
// mockCA is a stub certificate authority for the handler tests; it is the
// value assigned to the acme.CertificateAuthority field of the test cases.
// IsRevoked and Revoke delegate to the optional hook fields so each test can
// script their outcome; with the hooks left nil every method succeeds with
// zero values.
type mockCA struct {
	// MockIsRevoked, when set, supplies the result of IsRevoked.
	MockIsRevoked func(sn string) (bool, error)
	// MockRevoke, when set, supplies the result of Revoke.
	MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
}

// Sign is a no-op: it returns no certificates and no error.
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
	return nil, nil
}

// IsRevoked returns whatever MockIsRevoked returns for sn, or (false, nil)
// when no hook is installed.
func (m *mockCA) IsRevoked(sn string) (bool, error) {
	if m.MockIsRevoked == nil {
		return false, nil
	}
	return m.MockIsRevoked(sn)
}

// Revoke returns whatever MockRevoke returns for ctx and opts, or nil when no
// hook is installed.
func (m *mockCA) Revoke(ctx context.Context, opts *authority.RevokeOptions) error {
	if m.MockRevoke == nil {
		return nil
	}
	return m.MockRevoke(ctx, opts)
}

// LoadProvisionerByName is a no-op: it returns a nil provisioner and no error.
func (m *mockCA) LoadProvisionerByName(string) (provisioner.Interface, error) {
	return nil, nil
}
// Test_validateReasonCode checks validateReasonCode against one accepted
// reason code and three codes that are expected to be rejected with a
// badRevocationReason error.
func Test_validateReasonCode(t *testing.T) {
	tests := []struct {
		name       string
		reasonCode *int
		want       *acme.Error
	}{
		{
			name:       "ok",
			reasonCode: v(ocsp.Unspecified),
			want:       nil,
		},
		{
			name:       "fail/too-low",
			reasonCode: v(-1),
			want:       acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"),
		},
		{
			name:       "fail/too-high",
			reasonCode: v(11),
			want:       acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"),
		},
		{
			// 7 is expected to be rejected like the out-of-range values above.
			name:       "fail/missing-7",
			reasonCode: v(7),
			want:       acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateReasonCode(tt.reasonCode)
			if (err != nil) != (tt.want != nil) {
				// Stop here: when an error is returned but none is wanted,
				// the field comparisons below would dereference a nil
				// tt.want and panic instead of reporting this failure.
				t.Fatalf("validateReasonCode() = %v, want %v", err, tt.want)
			}
			if err != nil {
				assert.Equals(t, err.Type, tt.want.Type)
				assert.Equals(t, err.Detail, tt.want.Detail)
				assert.Equals(t, err.Status, tt.want.Status)
				assert.Equals(t, err.Err.Error(), tt.want.Err.Error())
			}
		})
	}
}
// Test_reason verifies the human-readable description that reason returns
// for each revocation reason code, plus the "unspecified reason" fallback it
// is expected to return for a code outside the known set.
func Test_reason(t *testing.T) {
	type testCase struct {
		name       string
		reasonCode int
		want       string
	}
	cases := []testCase{
		{"unspecified reason", ocsp.Unspecified, "unspecified reason"},
		{"key compromised", ocsp.KeyCompromise, "key compromised"},
		{"ca compromised", ocsp.CACompromise, "ca compromised"},
		{"affiliation changed", ocsp.AffiliationChanged, "affiliation changed"},
		{"superseded", ocsp.Superseded, "superseded"},
		{"cessation of operation", ocsp.CessationOfOperation, "cessation of operation"},
		{"certificate hold", ocsp.CertificateHold, "certificate hold"},
		{"remove from crl", ocsp.RemoveFromCRL, "remove from crl"},
		{"privilege withdrawn", ocsp.PrivilegeWithdrawn, "privilege withdrawn"},
		{"aa compromised", ocsp.AACompromise, "aa compromised"},
		// An unknown code falls back to the generic description.
		{"default", -1, "unspecified reason"},
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := reason(tc.reasonCode)
			if got == tc.want {
				return
			}
			t.Errorf("reason() = %v, want %v", got, tc.want)
		})
	}
}
// Test_revokeOptions checks the authority.RevokeOptions that revokeOptions
// builds from a serial number, a certificate and an optional reason code.
func Test_revokeOptions(t *testing.T) {
	// A nil certificate is passed in every case; the expected options
	// therefore carry a nil Crt as well.
	var cert *x509.Certificate

	type args struct {
		serial          string
		certToBeRevoked *x509.Certificate
		reasonCode      *int
	}
	type testCase struct {
		name string
		args args
		want *authority.RevokeOptions
	}

	// Without a reason code only the serial and the ACME flag are expected.
	withoutReason := testCase{
		name: "ok/no-reasoncode",
		args: args{
			serial:          "1234",
			certToBeRevoked: cert,
		},
		want: &authority.RevokeOptions{
			Serial: "1234",
			Crt:    nil,
			ACME:   true,
		},
	}
	// With a reason code both the numeric code and its description are expected.
	withReason := testCase{
		name: "ok/including-reasoncode",
		args: args{
			serial:          "1234",
			certToBeRevoked: cert,
			reasonCode:      v(ocsp.KeyCompromise),
		},
		want: &authority.RevokeOptions{
			Serial:     "1234",
			Crt:        nil,
			ACME:       true,
			ReasonCode: 1,
			Reason:     "key compromised",
		},
	}

	for _, tc := range []testCase{withoutReason, withReason} {
		t.Run(tc.name, func(t *testing.T) {
			got := revokeOptions(tc.args.serial, tc.args.certToBeRevoked, tc.args.reasonCode)
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("revokeOptions() = %v, want %v", got, tc.want)
		})
	}
}
// TestHandler_RevokeCert drives Handler.RevokeCert through its failure paths
// and one success path.
//
// Each case supplies a mock DB, a mock CA and a request context carrying the
// provisioner, account, payload, JWS, base URL and chi route context. The
// test then checks the response status and either the returned ACME problem
// document or, on success, the empty body and the Link header.
func TestHandler_RevokeCert(t *testing.T) {
	prov := &provisioner.ACME{
		Type: "ACME",
		Name: "testprov",
	}
	escProvName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	chiCtx := chi.NewRouteContext()
	revokeURL := fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), escProvName)

	// The default payload carries the fixture certificate in unpadded
	// URL-safe base64.
	cert := parseCertificate(certPEM)
	rp := &revokePayload{
		Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw),
	}
	payloadBytes, err := json.Marshal(rp)
	assert.FatalError(t, err)

	type test struct {
		db         acme.DB
		ca         acme.CertificateAuthority
		ctx        context.Context
		statusCode int
		err        *acme.Error
	}

	// newContext builds the request context shared by every case; only the
	// account and the payload bytes vary between cases.
	newContext := func(acc *acme.Account, payload []byte) context.Context {
		jws := &jose.JSONWebSignature{
			Signatures: []jose.Signature{
				{
					Protected: jose.Header{
						Algorithm: jose.ES256,
						KeyID:     "bar",
						ExtraHeaders: map[jose.HeaderKey]interface{}{
							"url": revokeURL,
						},
					},
				},
			},
		}
		ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
		ctx = context.WithValue(ctx, accContextKey, acc)
		ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payload})
		ctx = context.WithValue(ctx, jwsContextKey, jws)
		ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
		ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
		return ctx
	}

	// dbOwnedBy returns a mock DB whose certificate lookup asserts that the
	// fixture's serial is requested and reports the certificate as belonging
	// to accountID.
	dbOwnedBy := func(t *testing.T, accountID string) *acme.MockDB {
		return &acme.MockDB{
			MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
				assert.Equals(t, cert.SerialNumber.String(), serial)
				return &acme.Certificate{
					AccountID: accountID,
				}, nil
			},
		}
	}

	var tests = map[string]func(t *testing.T) test{
		"fail/wrong-certificate-encoding": func(t *testing.T) test {
			// Standard (padded, non-URL-safe) base64 instead of the raw URL
			// encoding used by the default payload.
			rp := &revokePayload{
				Certificate: base64.StdEncoding.EncodeToString(cert.Raw),
			}
			wronglyEncodedPayloadBytes, err := json.Marshal(rp)
			assert.FatalError(t, err)
			acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
			return test{
				db:         &acme.MockDB{},
				ca:         &mockCA{},
				ctx:        newContext(acc, wronglyEncodedPayloadBytes),
				statusCode: 400,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:malformed",
					Status: 400,
					Detail: "The request message was malformed",
				},
			}
		},
		"fail/no-certificate-encoded": func(t *testing.T) test {
			// Correct encoding, but of zero certificate bytes.
			rp := &revokePayload{
				Certificate: base64.RawURLEncoding.EncodeToString([]byte{}),
			}
			wrongPayloadBytes, err := json.Marshal(rp)
			assert.FatalError(t, err)
			acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
			return test{
				db:         &acme.MockDB{},
				ca:         &mockCA{},
				ctx:        newContext(acc, wrongPayloadBytes),
				statusCode: 400,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:malformed",
					Status: 400,
					Detail: "The request message was malformed",
				},
			}
		},
		"fail/account-not-valid": func(t *testing.T) test {
			// The certificate belongs to the requesting account, but that
			// account is in the invalid state.
			acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
			return test{
				db:         dbOwnedBy(t, "accountID"),
				ca:         &mockCA{},
				ctx:        newContext(acc, payloadBytes),
				statusCode: 403,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:unauthorized",
					Detail: fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()),
					Status: 403,
				},
			}
		},
		"fail/account-not-authorized": func(t *testing.T) test {
			// The account is valid, but the certificate is recorded under a
			// different account.
			acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
			return test{
				db:         dbOwnedBy(t, "differentAccountID"),
				ca:         &mockCA{},
				ctx:        newContext(acc, payloadBytes),
				statusCode: 403,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:unauthorized",
					Detail: fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()),
					Status: 403,
				},
			}
		},
		"fail/certificate-revoked-check-fails": func(t *testing.T) test {
			acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
			ca := &mockCA{
				MockIsRevoked: func(sn string) (bool, error) {
					return false, errors.New("force")
				},
			}
			return test{
				db:         dbOwnedBy(t, "accountID"),
				ca:         ca,
				ctx:        newContext(acc, payloadBytes),
				statusCode: 500,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:serverInternal",
					Detail: "The server experienced an internal error",
					Status: 500,
				},
			}
		},
		"fail/certificate-already-revoked": func(t *testing.T) test {
			acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
			ca := &mockCA{
				MockIsRevoked: func(sn string) (bool, error) {
					return true, nil
				},
			}
			return test{
				db:         dbOwnedBy(t, "accountID"),
				ca:         ca,
				ctx:        newContext(acc, payloadBytes),
				statusCode: 400,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:alreadyRevoked",
					Detail: "Certificate already revoked",
					Status: 400,
				},
			}
		},
		"fail/certificate-revoke-fails": func(t *testing.T) test {
			acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
			ca := &mockCA{
				MockRevoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
					return errors.New("force")
				},
			}
			return test{
				db:         dbOwnedBy(t, "accountID"),
				ca:         ca,
				ctx:        newContext(acc, payloadBytes),
				statusCode: 500,
				err: &acme.Error{
					Type:   "urn:ietf:params:acme:error:serverInternal",
					Detail: "The server experienced an internal error",
					Status: 500,
				},
			}
		},
		"ok/using-account-key": func(t *testing.T) test {
			acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
			return test{
				db:         dbOwnedBy(t, "accountID"),
				ca:         &mockCA{},
				ctx:        newContext(acc, payloadBytes),
				statusCode: 200,
			}
		},
	}
	for name, setup := range tests {
		t.Run(name, func(t *testing.T) {
			// Build the case with the subtest's t so that failures raised
			// during setup or inside the mock callbacks are reported against
			// this subtest rather than the parent test.
			tc := setup(t)
			h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
			req := httptest.NewRequest("POST", revokeURL, nil)
			req = req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			h.RevokeCert(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				// Error responses must be ACME problem documents matching
				// the expected type and detail.
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Identifier, tc.err.Identifier)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				// A successful revocation has an empty body and links back
				// to the directory.
				assert.True(t, bytes.Equal(bytes.TrimSpace(body), []byte{}))
				assert.Equals(t, int64(0), req.ContentLength)
				assert.Equals(t, []string{fmt.Sprintf("<%s/acme/%s/directory>;rel=\"index\"", baseURL.String(), escProvName)}, res.Header["Link"])
			}
		})
	}
}