forked from TrueCloudLab/certificates
Merge branch 'master' into herman/ip-sans-improvements
This commit is contained in:
commit
13a31fd862
42 changed files with 2482 additions and 162 deletions
|
@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||||
|
|
||||||
## [Unreleased - 0.18.1] - DATE
|
## [Unreleased - 0.18.1] - DATE
|
||||||
### Added
|
### Added
|
||||||
|
- Support for ACME revocation.
|
||||||
### Changed
|
### Changed
|
||||||
### Deprecated
|
### Deprecated
|
||||||
### Removed
|
### Removed
|
||||||
|
|
|
@ -42,6 +42,8 @@ To get up and running quickly, or as an alternative to running your own `step-ca
|
||||||
[![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers)
|
[![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers)
|
||||||
[![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs)
|
[![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs)
|
||||||
|
|
||||||
|
![star us](https://github.com/smallstep/certificates/raw/master/docs/images/star.gif)
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
### 🦾 A fast, stable, flexible private CA
|
### 🦾 A fast, stable, flexible private CA
|
||||||
|
|
|
@ -100,11 +100,17 @@ func (h *Handler) Route(r api.Router) {
|
||||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
||||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
||||||
|
|
||||||
|
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||||
|
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
|
||||||
|
}
|
||||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||||
}
|
}
|
||||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||||
|
}
|
||||||
|
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||||
|
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
||||||
|
@ -117,6 +123,7 @@ func (h *Handler) Route(r api.Router) {
|
||||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
||||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||||
|
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonce just sets the right header since a Nonce is added to each response
|
// GetNonce just sets the right header since a Nonce is added to each response
|
||||||
|
|
|
@ -262,11 +262,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
// Store the JWK in the context.
|
// Store the JWK in the context.
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
|
||||||
// Get Account or continue to generate a new one.
|
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||||
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, acme.ErrNotFound):
|
case errors.Is(err, acme.ErrNotFound):
|
||||||
// For NewAccount requests ...
|
// For NewAccount and Revoke requests ...
|
||||||
break
|
break
|
||||||
case err != nil:
|
case err != nil:
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
|
@ -352,6 +352,42 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||||
|
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||||
|
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
jws, err := jwsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// at this point the JWS has already been verified (if correctly configured in middleware),
|
||||||
|
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||||
|
// signed the payload with a certificate private key.
|
||||||
|
if canExtractJWKFrom(jws) {
|
||||||
|
h.extractJWK(next)(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||||
|
// signed the payload with an account private key.
|
||||||
|
h.lookupJWK(next)(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
|
||||||
|
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
||||||
|
if jws == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(jws.Signatures) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return jws.Signatures[0].Protected.JSONWebKey != nil
|
||||||
|
}
|
||||||
|
|
||||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
|
|
|
@ -1472,3 +1472,187 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_canExtractJWKFrom(t *testing.T) {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type args struct {
|
||||||
|
jws *jose.JSONWebSignature
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-jws",
|
||||||
|
args: args{
|
||||||
|
jws: nil,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-signatures",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-jwk",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{
|
||||||
|
{
|
||||||
|
Protected: jose.Header{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{
|
||||||
|
{
|
||||||
|
Protected: jose.Header{
|
||||||
|
JSONWebKey: jwk,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
|
||||||
|
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
|
u := "https://ca.smallstep.com/acme/account"
|
||||||
|
type test struct {
|
||||||
|
db acme.DB
|
||||||
|
linker Linker
|
||||||
|
statusCode int
|
||||||
|
ctx context.Context
|
||||||
|
err *acme.Error
|
||||||
|
next func(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"ok/extract": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub := jwk.Public()
|
||||||
|
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithHeader("jwk", pub) // JWK for certificate private key flow
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||||
|
Key: jwk.Key,
|
||||||
|
}, so)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signed, err := signer.Sign([]byte("foo"))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
raw, err := signed.CompactSerialize()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, kid, pub.KeyID)
|
||||||
|
return nil, acme.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
|
||||||
|
statusCode: 200,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/lookup": func(t *testing.T) test {
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
accID := "accID"
|
||||||
|
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) // KID for account private key flow
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||||
|
Key: jwk.Key,
|
||||||
|
}, so)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jws, err := signer.Sign([]byte("baz"))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
raw, err := jws.CompactSerialize()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{db: tc.db, linker: tc.linker}
|
||||||
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.extractOrLookupJWK(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
|
var ae acme.Error
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
287
acme/api/revoke.go
Normal file
287
acme/api/revoke.go
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type revokePayload struct {
|
||||||
|
Certificate string `json:"certificate"`
|
||||||
|
ReasonCode *int `json:"reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeCert attempts to revoke a certificate.
|
||||||
|
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
jws, err := jwsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prov, err := provisionerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := payloadFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var p revokePayload
|
||||||
|
err = json.Unmarshal(payload.value, &p)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
|
||||||
|
if err != nil {
|
||||||
|
// in this case the most likely cause is a client that didn't properly encode the certificate
|
||||||
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certToBeRevoked, err := x509.ParseCertificate(certBytes)
|
||||||
|
if err != nil {
|
||||||
|
// in this case a client may have encoded something different than a certificate
|
||||||
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serial := certToBeRevoked.SerialNumber.String()
|
||||||
|
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
||||||
|
// this should never happen
|
||||||
|
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldCheckAccountFrom(jws) {
|
||||||
|
account, err := accountFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||||
|
if acmeErr != nil {
|
||||||
|
api.WriteError(w, acmeErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if account doesn't need to be checked, the JWS should be verified to be signed by the
|
||||||
|
// private key that belongs to the public key in the certificate to be revoked.
|
||||||
|
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
||||||
|
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasBeenRevokedBefore {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reasonCode := p.ReasonCode
|
||||||
|
acmeErr := validateReasonCode(reasonCode)
|
||||||
|
if acmeErr != nil {
|
||||||
|
api.WriteError(w, acmeErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize revocation by ACME provisioner
|
||||||
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
|
||||||
|
err = prov.AuthorizeRevoke(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||||
|
err = h.ca.Revoke(ctx, options)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, wrapRevokeErr(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logRevoke(w, options)
|
||||||
|
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||||
|
w.Write(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized
|
||||||
|
// to revoke the certificate. An Account must always be valid in order to revoke a certificate.
|
||||||
|
// In case the certificate retrieved from the database belongs to the Account, the Account is
|
||||||
|
// authorized. If the certificate retrieved from the database doesn't belong to the Account,
|
||||||
|
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||||
|
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||||
|
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||||
|
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||||
|
if !account.IsValid() {
|
||||||
|
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||||
|
}
|
||||||
|
certificateBelongsToAccount := dbCert.AccountID == account.ID
|
||||||
|
if certificateBelongsToAccount {
|
||||||
|
return nil // return early
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(hs): according to RFC8555: 7.6, a server MUST consider the following accounts authorized
|
||||||
|
// to revoke a certificate:
|
||||||
|
//
|
||||||
|
// o the account that issued the certificate.
|
||||||
|
// o an account that holds authorizations for all of the identifiers in the certificate.
|
||||||
|
//
|
||||||
|
// We currently only support the first case. The second might result in step going OOM when
|
||||||
|
// large numbers of Authorizations are involved when the current nosql interface is in use.
|
||||||
|
// We want to protect users from this failure scenario, so that's why it hasn't been added yet.
|
||||||
|
// This issue is tracked in https://github.com/smallstep/certificates/issues/767
|
||||||
|
|
||||||
|
// not authorized; fail closed.
|
||||||
|
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' is not authorized", account.ID), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapRevokeErr is a best effort implementation to transform an error during
|
||||||
|
// revocation into an ACME error, so that clients can understand the error.
|
||||||
|
func wrapRevokeErr(err error) *acme.Error {
|
||||||
|
t := err.Error()
|
||||||
|
if strings.Contains(t, "is already revoked") {
|
||||||
|
return acme.NewError(acme.ErrorAlreadyRevokedType, t)
|
||||||
|
}
|
||||||
|
return acme.WrapErrorISE(err, "error when revoking certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// unauthorizedError returns an ACME error indicating the request was
|
||||||
|
// not authorized to revoke the certificate.
|
||||||
|
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
|
||||||
|
var acmeErr *acme.Error
|
||||||
|
if err == nil {
|
||||||
|
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
|
||||||
|
} else {
|
||||||
|
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
|
||||||
|
}
|
||||||
|
acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(unauthorizedIdentifiers) > 0:
|
||||||
|
identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too?
|
||||||
|
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value)
|
||||||
|
case cert.Subject.String() != "":
|
||||||
|
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName)
|
||||||
|
default:
|
||||||
|
acmeErr.Detail = "No authorization provided"
|
||||||
|
}
|
||||||
|
|
||||||
|
return acmeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// logRevoke logs successful revocation of certificate
|
||||||
|
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||||
|
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"serial": ri.Serial,
|
||||||
|
"reasonCode": ri.ReasonCode,
|
||||||
|
"reason": ri.Reason,
|
||||||
|
"passiveOnly": ri.PassiveOnly,
|
||||||
|
"ACME": ri.ACME,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateReasonCode validates the revocation reason
|
||||||
|
func validateReasonCode(reasonCode *int) *acme.Error {
|
||||||
|
if reasonCode != nil && ((*reasonCode < ocsp.Unspecified || *reasonCode > ocsp.AACompromise) || *reasonCode == 7) {
|
||||||
|
return acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds")
|
||||||
|
}
|
||||||
|
// NOTE: it's possible to add additional requirements to the reason code:
|
||||||
|
// The server MAY disallow a subset of reasonCodes from being
|
||||||
|
// used by the user. If a request contains a disallowed reasonCode,
|
||||||
|
// then the server MUST reject it with the error type
|
||||||
|
// "urn:ietf:params:acme:error:badRevocationReason"
|
||||||
|
// No additional checks have been implemented so far.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// revokeOptions determines the RevokeOptions for the Authority to use in revocation
|
||||||
|
func revokeOptions(serial string, certToBeRevoked *x509.Certificate, reasonCode *int) *authority.RevokeOptions {
|
||||||
|
opts := &authority.RevokeOptions{
|
||||||
|
Serial: serial,
|
||||||
|
ACME: true,
|
||||||
|
Crt: certToBeRevoked,
|
||||||
|
}
|
||||||
|
if reasonCode != nil { // NOTE: when implementing CRL and/or OCSP, and reason code is missing, CRL entry extension should be omitted
|
||||||
|
opts.Reason = reason(*reasonCode)
|
||||||
|
opts.ReasonCode = *reasonCode
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
// reason transforms an integer reason code to a
|
||||||
|
// textual description of the revocation reason.
|
||||||
|
func reason(reasonCode int) string {
|
||||||
|
switch reasonCode {
|
||||||
|
case ocsp.Unspecified:
|
||||||
|
return "unspecified reason"
|
||||||
|
case ocsp.KeyCompromise:
|
||||||
|
return "key compromised"
|
||||||
|
case ocsp.CACompromise:
|
||||||
|
return "ca compromised"
|
||||||
|
case ocsp.AffiliationChanged:
|
||||||
|
return "affiliation changed"
|
||||||
|
case ocsp.Superseded:
|
||||||
|
return "superseded"
|
||||||
|
case ocsp.CessationOfOperation:
|
||||||
|
return "cessation of operation"
|
||||||
|
case ocsp.CertificateHold:
|
||||||
|
return "certificate hold"
|
||||||
|
case ocsp.RemoveFromCRL:
|
||||||
|
return "remove from crl"
|
||||||
|
case ocsp.PrivilegeWithdrawn:
|
||||||
|
return "privilege withdrawn"
|
||||||
|
case ocsp.AACompromise:
|
||||||
|
return "aa compromised"
|
||||||
|
default:
|
||||||
|
return "unspecified reason"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldCheckAccountFrom indicates whether an account should be
|
||||||
|
// retrieved from the context, so that it can be used for
|
||||||
|
// additional checks. This should only be done when no JWK
|
||||||
|
// can be extracted from the request, as that would indicate
|
||||||
|
// that the revocation request was signed with a certificate
|
||||||
|
// key pair (and not an account key pair). Looking up such
|
||||||
|
// a JWK would result in no Account being found.
|
||||||
|
func shouldCheckAccountFrom(jws *jose.JSONWebSignature) bool {
|
||||||
|
return !canExtractJWKFrom(jws)
|
||||||
|
}
|
1316
acme/api/revoke_test.go
Normal file
1316
acme/api/revoke_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -26,8 +26,11 @@ import (
|
||||||
type ChallengeType string
|
type ChallengeType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HTTP01 ChallengeType = "http-01"
|
// HTTP01 is the http-01 ACME challenge type
|
||||||
DNS01 ChallengeType = "dns-01"
|
HTTP01 ChallengeType = "http-01"
|
||||||
|
// DNS01 is the dns-01 ACME challenge type
|
||||||
|
DNS01 ChallengeType = "dns-01"
|
||||||
|
// TLSALPN01 is the tls-alpn-01 ACME challenge type
|
||||||
TLSALPN01 ChallengeType = "tls-alpn-01"
|
TLSALPN01 ChallengeType = "tls-alpn-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,15 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertificateAuthority is the interface implemented by a CA authority.
|
// CertificateAuthority is the interface implemented by a CA authority.
|
||||||
type CertificateAuthority interface {
|
type CertificateAuthority interface {
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
|
IsRevoked(sn string) (bool, error)
|
||||||
|
Revoke(context.Context, *authority.RevokeOptions) error
|
||||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +31,7 @@ var clock Clock
|
||||||
// only those methods required by the ACME api/authority.
|
// only those methods required by the ACME api/authority.
|
||||||
type Provisioner interface {
|
type Provisioner interface {
|
||||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||||
|
AuthorizeRevoke(ctx context.Context, token string) error
|
||||||
GetID() string
|
GetID() string
|
||||||
GetName() string
|
GetName() string
|
||||||
DefaultTLSCertDuration() time.Duration
|
DefaultTLSCertDuration() time.Duration
|
||||||
|
@ -41,6 +45,7 @@ type MockProvisioner struct {
|
||||||
MgetID func() string
|
MgetID func() string
|
||||||
MgetName func() string
|
MgetName func() string
|
||||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
|
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||||
MdefaultTLSCertDuration func() time.Duration
|
MdefaultTLSCertDuration func() time.Duration
|
||||||
MgetOptions func() *provisioner.Options
|
MgetOptions func() *provisioner.Options
|
||||||
}
|
}
|
||||||
|
@ -61,6 +66,14 @@ func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]prov
|
||||||
return m.Mret1.([]provisioner.SignOption), m.Merr
|
return m.Mret1.([]provisioner.SignOption), m.Merr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke mock
|
||||||
|
func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
|
if m.MauthorizeRevoke != nil {
|
||||||
|
return m.MauthorizeRevoke(ctx, token)
|
||||||
|
}
|
||||||
|
return m.Merr
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultTLSCertDuration mock
|
// DefaultTLSCertDuration mock
|
||||||
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
|
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
|
||||||
if m.MdefaultTLSCertDuration != nil {
|
if m.MdefaultTLSCertDuration != nil {
|
||||||
|
|
34
acme/db.go
34
acme/db.go
|
@ -25,9 +25,11 @@ type DB interface {
|
||||||
CreateAuthorization(ctx context.Context, az *Authorization) error
|
CreateAuthorization(ctx context.Context, az *Authorization) error
|
||||||
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
||||||
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
||||||
|
GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||||
|
|
||||||
CreateCertificate(ctx context.Context, cert *Certificate) error
|
CreateCertificate(ctx context.Context, cert *Certificate) error
|
||||||
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
||||||
|
GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error)
|
||||||
|
|
||||||
CreateChallenge(ctx context.Context, ch *Challenge) error
|
CreateChallenge(ctx context.Context, ch *Challenge) error
|
||||||
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
|
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||||
|
@ -50,12 +52,14 @@ type MockDB struct {
|
||||||
MockCreateNonce func(ctx context.Context) (Nonce, error)
|
MockCreateNonce func(ctx context.Context) (Nonce, error)
|
||||||
MockDeleteNonce func(ctx context.Context, nonce Nonce) error
|
MockDeleteNonce func(ctx context.Context, nonce Nonce) error
|
||||||
|
|
||||||
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
||||||
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
||||||
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
||||||
|
MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||||
|
|
||||||
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
||||||
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
||||||
|
MockGetCertificateBySerial func(ctx context.Context, serial string) (*Certificate, error)
|
||||||
|
|
||||||
MockCreateChallenge func(ctx context.Context, ch *Challenge) error
|
MockCreateChallenge func(ctx context.Context, ch *Challenge) error
|
||||||
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
|
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||||
|
@ -160,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err
|
||||||
return m.MockError
|
return m.MockError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAuthorizationsByAccountID mock
|
||||||
|
func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) {
|
||||||
|
if m.MockGetAuthorizationsByAccountID != nil {
|
||||||
|
return m.MockGetAuthorizationsByAccountID(ctx, accountID)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
// CreateCertificate mock
|
// CreateCertificate mock
|
||||||
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
||||||
if m.MockCreateCertificate != nil {
|
if m.MockCreateCertificate != nil {
|
||||||
|
@ -180,6 +194,16 @@ func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, e
|
||||||
return m.MockRet1.(*Certificate), m.MockError
|
return m.MockRet1.(*Certificate), m.MockError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateBySerial mock
|
||||||
|
func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) {
|
||||||
|
if m.MockGetCertificateBySerial != nil {
|
||||||
|
return m.MockGetCertificateBySerial(ctx, serial)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Certificate), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
// CreateChallenge mock
|
// CreateChallenge mock
|
||||||
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
|
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
|
||||||
if m.MockCreateChallenge != nil {
|
if m.MockCreateChallenge != nil {
|
||||||
|
|
|
@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
|
||||||
nu.Error = az.Error
|
nu.Error = az.Error
|
||||||
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
|
||||||
|
func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
|
||||||
|
entries, err := db.db.List(authzTable)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error listing authz")
|
||||||
|
}
|
||||||
|
authzs := []*acme.Authorization{}
|
||||||
|
for _, entry := range entries {
|
||||||
|
dbaz := new(dbAuthz)
|
||||||
|
if err = json.Unmarshal(entry.Value, dbaz); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key))
|
||||||
|
}
|
||||||
|
// Filter out all dbAuthzs that don't belong to the accountID. This
|
||||||
|
// could be made more efficient with additional data structures mapping the
|
||||||
|
// Account ID to authorizations. Not trivial to do, though.
|
||||||
|
if dbaz.AccountID != accountID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
authzs = append(authzs, &acme.Authorization{
|
||||||
|
ID: dbaz.ID,
|
||||||
|
AccountID: dbaz.AccountID,
|
||||||
|
Identifier: dbaz.Identifier,
|
||||||
|
Status: dbaz.Status,
|
||||||
|
Challenges: nil, // challenges not required for current use case
|
||||||
|
Wildcard: dbaz.Wildcard,
|
||||||
|
ExpiresAt: dbaz.ExpiresAt,
|
||||||
|
Token: dbaz.Token,
|
||||||
|
Error: dbaz.Error,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return authzs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,9 +3,11 @@ package nosql
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
|
@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
|
||||||
|
azID := "azID"
|
||||||
|
accountID := "accountID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
authzs []*acme.Authorization
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.List-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error listing authz: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal": func(t *testing.T) test {
|
||||||
|
b := []byte(`{malformed}`)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
return []*nosqldb.Entry{
|
||||||
|
{
|
||||||
|
Bucket: bucket,
|
||||||
|
Key: []byte(azID),
|
||||||
|
Value: b,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
authzs: nil,
|
||||||
|
err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: accountID,
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
return []*nosqldb.Entry{
|
||||||
|
{
|
||||||
|
Bucket: bucket,
|
||||||
|
Key: []byte(azID),
|
||||||
|
Value: b,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
authzs: []*acme.Authorization{
|
||||||
|
{
|
||||||
|
ID: dbaz.ID,
|
||||||
|
AccountID: dbaz.AccountID,
|
||||||
|
Token: dbaz.Token,
|
||||||
|
Identifier: dbaz.Identifier,
|
||||||
|
Status: dbaz.Status,
|
||||||
|
Challenges: nil,
|
||||||
|
Wildcard: dbaz.Wildcard,
|
||||||
|
ExpiresAt: dbaz.ExpiresAt,
|
||||||
|
Error: dbaz.Error,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/skip-different-account": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "differentAccountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
return []*nosqldb.Entry{
|
||||||
|
{
|
||||||
|
Bucket: bucket,
|
||||||
|
Key: []byte(azID),
|
||||||
|
Value: b,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
authzs: []*acme.Authorization{},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
d := DB{db: tc.db}
|
||||||
|
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if assert.Nil(t, tc.err) {
|
||||||
|
if !cmp.Equal(azs, tc.authzs) {
|
||||||
|
t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,11 @@ type dbCert struct {
|
||||||
Intermediates []byte `json:"intermediates"`
|
Intermediates []byte `json:"intermediates"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dbSerial struct {
|
||||||
|
Serial string `json:"serial"`
|
||||||
|
CertificateID string `json:"certificateID"`
|
||||||
|
}
|
||||||
|
|
||||||
// CreateCertificate creates and stores an ACME certificate type.
|
// CreateCertificate creates and stores an ACME certificate type.
|
||||||
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
|
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
|
||||||
var err error
|
var err error
|
||||||
|
@ -49,7 +54,17 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
|
||||||
Intermediates: intermediates,
|
Intermediates: intermediates,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
|
err = db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
serial := cert.Leaf.SerialNumber.String()
|
||||||
|
dbSerial := &dbSerial{
|
||||||
|
Serial: serial,
|
||||||
|
CertificateID: cert.ID,
|
||||||
|
}
|
||||||
|
return db.save(ctx, serial, dbSerial, nil, "serial", certBySerialTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
||||||
|
@ -80,6 +95,24 @@ func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateBySerial retrieves and unmarshals an ACME certificate type from the
|
||||||
|
// datastore based on a certificate serial number.
|
||||||
|
func (db *DB) GetCertificateBySerial(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
b, err := db.db.Get(certBySerialTable, []byte(serial))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error loading certificate ID for serial %s", serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbSerial := new(dbSerial)
|
||||||
|
if err := json.Unmarshal(b, dbSerial); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling certificate with serial %s", serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.GetCertificate(ctx, dbSerial.CertificateID)
|
||||||
|
}
|
||||||
|
|
||||||
func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package nosql
|
package nosql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -14,7 +16,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -75,18 +76,36 @@ func TestDB_CreateCertificate(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
*idPtr = string(key)
|
if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) {
|
||||||
assert.Equals(t, bucket, certTable)
|
t.Fail()
|
||||||
assert.Equals(t, key, []byte(cert.ID))
|
}
|
||||||
assert.Equals(t, old, nil)
|
if bytes.Equal(bucket, certTable) {
|
||||||
|
*idPtr = string(key)
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbc := new(dbCert)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||||
|
assert.Equals(t, dbc.ID, string(key))
|
||||||
|
assert.Equals(t, dbc.ID, cert.ID)
|
||||||
|
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||||
|
}
|
||||||
|
if bytes.Equal(bucket, certBySerialTable) {
|
||||||
|
assert.Equals(t, bucket, certBySerialTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String()))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbs := new(dbSerial)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbs))
|
||||||
|
assert.Equals(t, dbs.Serial, string(key))
|
||||||
|
assert.Equals(t, dbs.CertificateID, cert.ID)
|
||||||
|
|
||||||
|
*idPtr = cert.ID
|
||||||
|
}
|
||||||
|
|
||||||
dbc := new(dbCert)
|
|
||||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
|
||||||
assert.Equals(t, dbc.ID, string(key))
|
|
||||||
assert.Equals(t, dbc.ID, cert.ID)
|
|
||||||
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
|
||||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
|
||||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -317,3 +336,135 @@ func Test_parseBundle(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDB_GetCertificateBySerial(t *testing.T) {
|
||||||
|
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
certID := "certID"
|
||||||
|
serial := ""
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if bytes.Equal(bucket, certBySerialTable) {
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, errors.New("wrong table")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if bytes.Equal(bucket, certBySerialTable) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
return nil, errors.New("wrong table")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: fmt.Errorf("error loading certificate ID for serial %s", serial),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-dbSerial": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if bytes.Equal(bucket, certBySerialTable) {
|
||||||
|
return []byte(`{"serial":malformed!}`), nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("wrong table")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: fmt.Errorf("error unmarshaling certificate with serial %s", serial),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
|
||||||
|
if bytes.Equal(bucket, certBySerialTable) {
|
||||||
|
certSerial := dbSerial{
|
||||||
|
Serial: serial,
|
||||||
|
CertificateID: certID,
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(certSerial)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(bucket, certTable) {
|
||||||
|
cert := dbCert{
|
||||||
|
ID: certID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
OrderID: "orderID",
|
||||||
|
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}),
|
||||||
|
Intermediates: append(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: inter.Raw,
|
||||||
|
}), pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: root.Raw,
|
||||||
|
})...),
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(cert)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("wrong table")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
d := DB{db: tc.db}
|
||||||
|
cert, err := d.GetCertificateBySerial(context.Background(), serial)
|
||||||
|
if err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, cert.ID, certID)
|
||||||
|
assert.Equals(t, cert.AccountID, "accountID")
|
||||||
|
assert.Equals(t, cert.OrderID, "orderID")
|
||||||
|
assert.Equals(t, cert.Leaf, leaf)
|
||||||
|
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ var (
|
||||||
orderTable = []byte("acme_orders")
|
orderTable = []byte("acme_orders")
|
||||||
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
||||||
certTable = []byte("acme_certs")
|
certTable = []byte("acme_certs")
|
||||||
|
certBySerialTable = []byte("acme_serial_certs_index")
|
||||||
)
|
)
|
||||||
|
|
||||||
// DB is a struct that implements the AcmeDB interface.
|
// DB is a struct that implements the AcmeDB interface.
|
||||||
|
@ -29,7 +30,7 @@ type DB struct {
|
||||||
// New configures and returns a new ACME DB backend implemented using a nosql DB.
|
// New configures and returns a new ACME DB backend implemented using a nosql DB.
|
||||||
func New(db nosqlDB.DB) (*DB, error) {
|
func New(db nosqlDB.DB) (*DB, error) {
|
||||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable}
|
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable, certBySerialTable}
|
||||||
for _, b := range tables {
|
for _, b := range tables {
|
||||||
if err := db.CreateTable(b); err != nil {
|
if err := db.CreateTable(b); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error creating table %s",
|
return nil, errors.Wrapf(err, "error creating table %s",
|
||||||
|
|
|
@ -147,7 +147,7 @@ var (
|
||||||
},
|
},
|
||||||
ErrorAlreadyRevokedType: {
|
ErrorAlreadyRevokedType: {
|
||||||
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
|
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
|
||||||
details: "Certificate already Revoked",
|
details: "Certificate already revoked",
|
||||||
status: 400,
|
status: 400,
|
||||||
},
|
},
|
||||||
ErrorBadCSRType: {
|
ErrorBadCSRType: {
|
||||||
|
|
|
@ -17,7 +17,9 @@ import (
|
||||||
type IdentifierType string
|
type IdentifierType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IP IdentifierType = "ip"
|
// IP is the ACME ip identifier type
|
||||||
|
IP IdentifierType = "ip"
|
||||||
|
// DNS is the ACME dns identifier type
|
||||||
DNS IdentifierType = "dns"
|
DNS IdentifierType = "dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
@ -287,6 +288,14 @@ func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface
|
||||||
return m.ret1.(provisioner.Interface), m.err
|
return m.ret1.(provisioner.Interface), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockSignAuth) IsRevoked(sn string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestOrder_Finalize(t *testing.T) {
|
func TestOrder_Finalize(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
o *Order
|
o *Order
|
||||||
|
|
|
@ -348,7 +348,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
roots, err := h.Authority.GetRoots()
|
roots, err := h.Authority.GetRoots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -366,7 +366,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
federated, err := h.Authority.GetFederation()
|
federated, err := h.Authority.GetFederation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certChainPEM := certChainToPEM(certChain)
|
certChainPEM := certChainToPEM(certChain)
|
||||||
|
|
|
@ -293,7 +293,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -301,7 +301,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addUserCertificate = &SSHCertificate{addUserCert}
|
addUserCertificate = &SSHCertificate{addUserCert}
|
||||||
|
@ -326,7 +326,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
identityCertificate = certChainToPEM(certChain)
|
identityCertificate = certChainToPEM(certChain)
|
||||||
|
|
|
@ -68,7 +68,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
opts.OTT = body.OTT
|
opts.OTT = body.OTT
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -588,6 +588,19 @@ func (a *Authority) CloseForReload() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRevoked returns whether or not a certificate has been
|
||||||
|
// revoked before.
|
||||||
|
func (a *Authority) IsRevoked(sn string) (bool, error) {
|
||||||
|
// Check the passive revocation table.
|
||||||
|
if lca, ok := a.adminDB.(interface {
|
||||||
|
IsRevoked(string) (bool, error)
|
||||||
|
}); ok {
|
||||||
|
return lca.IsRevoked(sn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.db.IsRevoked(sn)
|
||||||
|
}
|
||||||
|
|
||||||
// requiresDecrypter returns whether the Authority
|
// requiresDecrypter returns whether the Authority
|
||||||
// requires a KMS that provides a crypto.Decrypter
|
// requires a KMS that provides a crypto.Decrypter
|
||||||
// Currently this is only required when SCEP is
|
// Currently this is only required when SCEP is
|
||||||
|
|
|
@ -274,19 +274,9 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
//
|
//
|
||||||
// TODO(mariano): should we authorize by default?
|
// TODO(mariano): should we authorize by default?
|
||||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
var err error
|
|
||||||
var isRevoked bool
|
|
||||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
|
||||||
|
|
||||||
// Check the passive revocation table.
|
|
||||||
serial := cert.SerialNumber.String()
|
serial := cert.SerialNumber.String()
|
||||||
if lca, ok := a.adminDB.(interface {
|
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||||
IsRevoked(string) (bool, error)
|
isRevoked, err := a.IsRevoked(serial)
|
||||||
}); ok {
|
|
||||||
isRevoked, err = lca.IsRevoked(serial)
|
|
||||||
} else {
|
|
||||||
isRevoked, err = a.db.IsRevoked(serial)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,6 +99,15 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke is called just before the certificate is to be revoked by
|
||||||
|
// the CA. It can be used to authorize revocation of a certificate. It
|
||||||
|
// currently is a no-op.
|
||||||
|
// TODO(hs): add configuration option that toggles revocation? Or change function signature to make it more useful?
|
||||||
|
// Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx.
|
||||||
|
func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||||
// NOTE: This method does not actually validate the certificate or check it's
|
// NOTE: This method does not actually validate the certificate or check it's
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
|
|
|
@ -184,7 +184,6 @@ func TestUnimplementedMethods(t *testing.T) {
|
||||||
{"x5c/sshRenew", &X5C{}, SSHRenewMethod},
|
{"x5c/sshRenew", &X5C{}, SSHRenewMethod},
|
||||||
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
|
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
|
||||||
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
|
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
|
||||||
{"acme/revoke", &ACME{}, RevokeMethod},
|
|
||||||
{"acme/sshSign", &ACME{}, SSHSignMethod},
|
{"acme/sshSign", &ACME{}, SSHSignMethod},
|
||||||
{"acme/sshRekey", &ACME{}, SSHRekeyMethod},
|
{"acme/sshRekey", &ACME{}, SSHRekeyMethod},
|
||||||
{"acme/sshRenew", &ACME{}, SSHRenewMethod},
|
{"acme/sshRenew", &ACME{}, SSHRenewMethod},
|
||||||
|
|
|
@ -9,12 +9,14 @@ import (
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
|
"go.step.sm/crypto/keyutil"
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -83,19 +85,19 @@ type emailOnlyIdentity string
|
||||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||||
switch {
|
switch {
|
||||||
case len(req.DNSNames) > 0:
|
case len(req.DNSNames) > 0:
|
||||||
return errors.New("certificate request cannot contain DNS names")
|
return errs.Forbidden("certificate request cannot contain DNS names")
|
||||||
case len(req.IPAddresses) > 0:
|
case len(req.IPAddresses) > 0:
|
||||||
return errors.New("certificate request cannot contain IP addresses")
|
return errs.Forbidden("certificate request cannot contain IP addresses")
|
||||||
case len(req.URIs) > 0:
|
case len(req.URIs) > 0:
|
||||||
return errors.New("certificate request cannot contain URIs")
|
return errs.Forbidden("certificate request cannot contain URIs")
|
||||||
case len(req.EmailAddresses) == 0:
|
case len(req.EmailAddresses) == 0:
|
||||||
return errors.New("certificate request does not contain any email address")
|
return errs.Forbidden("certificate request does not contain any email address")
|
||||||
case len(req.EmailAddresses) > 1:
|
case len(req.EmailAddresses) > 1:
|
||||||
return errors.New("certificate request contains too many email addresses")
|
return errs.Forbidden("certificate request contains too many email addresses")
|
||||||
case req.EmailAddresses[0] == "":
|
case req.EmailAddresses[0] == "":
|
||||||
return errors.New("certificate request cannot contain an empty email address")
|
return errs.Forbidden("certificate request cannot contain an empty email address")
|
||||||
case req.EmailAddresses[0] != string(e):
|
case req.EmailAddresses[0] != string(e):
|
||||||
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e)
|
return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -108,12 +110,13 @@ type defaultPublicKeyValidator struct{}
|
||||||
func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error {
|
func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
switch k := req.PublicKey.(type) {
|
switch k := req.PublicKey.(type) {
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
if k.Size() < 256 {
|
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||||
return errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)")
|
return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)",
|
||||||
|
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
|
||||||
}
|
}
|
||||||
case *ecdsa.PublicKey, ed25519.PublicKey:
|
case *ecdsa.PublicKey, ed25519.PublicKey:
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("unrecognized public key of type '%T' in CSR", k)
|
return errs.BadRequest("certificate request key of type '%T' is not supported", k)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -139,11 +142,12 @@ func (v publicKeyMinimumLengthValidator) Valid(req *x509.CertificateRequest) err
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
minimumLengthInBytes := v.length / 8
|
minimumLengthInBytes := v.length / 8
|
||||||
if k.Size() < minimumLengthInBytes {
|
if k.Size() < minimumLengthInBytes {
|
||||||
return errors.Errorf("rsa key in CSR must be at least %d bits (%d bytes)", v.length, minimumLengthInBytes)
|
return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)",
|
||||||
|
v.length, minimumLengthInBytes)
|
||||||
}
|
}
|
||||||
case *ecdsa.PublicKey, ed25519.PublicKey:
|
case *ecdsa.PublicKey, ed25519.PublicKey:
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("unrecognized public key of type '%T' in CSR", k)
|
return errs.BadRequest("certificate request key of type '%T' is not supported", k)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -158,7 +162,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if req.Subject.CommonName != string(v) {
|
if req.Subject.CommonName != string(v) {
|
||||||
return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v)
|
return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -176,7 +180,7 @@ func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||||
|
@ -197,7 +201,7 @@ func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
got[s] = true
|
got[s] = true
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(want, got) {
|
if !reflect.DeepEqual(want, got) {
|
||||||
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
|
return errs.Forbidden("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -220,7 +224,7 @@ func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
got[ip.String()] = true
|
got[ip.String()] = true
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(want, got) {
|
if !reflect.DeepEqual(want, got) {
|
||||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v)
|
return errs.Forbidden("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -243,7 +247,7 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
got[s] = true
|
got[s] = true
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(want, got) {
|
if !reflect.DeepEqual(want, got) {
|
||||||
return errors.Errorf("certificate request does not contain the valid Email Addresses - got %v, want %v", req.EmailAddresses, v)
|
return errs.Forbidden("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -266,7 +270,7 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
got[u.String()] = true
|
got[u.String()] = true
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(want, got) {
|
if !reflect.DeepEqual(want, got) {
|
||||||
return errors.Errorf("URIs claim failed - got %v, want %v", req.URIs, v)
|
return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -334,15 +338,15 @@ func (v profileLimitDuration) Modify(cert *x509.Certificate, so SignOptions) err
|
||||||
backdate = -1 * so.Backdate
|
backdate = -1 * so.Backdate
|
||||||
}
|
}
|
||||||
if notBefore.Before(v.notBefore) {
|
if notBefore.Before(v.notBefore) {
|
||||||
return errors.Errorf("requested certificate notBefore (%s) is before "+
|
return errs.Forbidden(
|
||||||
"the active validity window of the provisioning credential (%s)",
|
"requested certificate notBefore (%s) is before the active validity window of the provisioning credential (%s)",
|
||||||
notBefore, v.notBefore)
|
notBefore, v.notBefore)
|
||||||
}
|
}
|
||||||
|
|
||||||
notAfter := so.NotAfter.RelativeTime(notBefore)
|
notAfter := so.NotAfter.RelativeTime(notBefore)
|
||||||
if notAfter.After(v.notAfter) {
|
if notAfter.After(v.notAfter) {
|
||||||
return errors.Errorf("requested certificate notAfter (%s) is after "+
|
return errs.Forbidden(
|
||||||
"the expiration of the provisioning credential (%s)",
|
"requested certificate notAfter (%s) is after the expiration of the provisioning credential (%s)",
|
||||||
notAfter, v.notAfter)
|
notAfter, v.notAfter)
|
||||||
}
|
}
|
||||||
if notAfter.IsZero() {
|
if notAfter.IsZero() {
|
||||||
|
@ -388,14 +392,14 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
|
||||||
return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
|
return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
|
||||||
}
|
}
|
||||||
if d < v.min {
|
if d < v.min {
|
||||||
return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
|
return errs.Forbidden("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
|
||||||
}
|
}
|
||||||
// NOTE: this check is not "technically correct". We're allowing the max
|
// NOTE: this check is not "technically correct". We're allowing the max
|
||||||
// duration of a cert to be "max + backdate" and not all certificates will
|
// duration of a cert to be "max + backdate" and not all certificates will
|
||||||
// be backdated (e.g. if a user passes the NotBefore value then we do not
|
// be backdated (e.g. if a user passes the NotBefore value then we do not
|
||||||
// apply a backdate). This is good enough.
|
// apply a backdate). This is good enough.
|
||||||
if d > v.max+o.Backdate {
|
if d > v.max+o.Backdate {
|
||||||
return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
|
return errs.Forbidden("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -422,16 +426,15 @@ func newForceCNOption(forceCN bool) *forceCNOption {
|
||||||
|
|
||||||
func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
||||||
if !o.ForceCN {
|
if !o.ForceCN {
|
||||||
// Forcing CN is disabled, do nothing to certificate
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Force the common name to be the first DNS if not provided.
|
||||||
if cert.Subject.CommonName == "" {
|
if cert.Subject.CommonName == "" {
|
||||||
if len(cert.DNSNames) > 0 {
|
if len(cert.DNSNames) == 0 {
|
||||||
cert.Subject.CommonName = cert.DNSNames[0]
|
return errs.BadRequest("cannot force common name, DNS names is empty")
|
||||||
} else {
|
|
||||||
return errors.New("Cannot force CN, DNSNames is empty")
|
|
||||||
}
|
}
|
||||||
|
cert.Subject.CommonName = cert.DNSNames[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -456,7 +459,7 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValue
|
||||||
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
||||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
|
||||||
}
|
}
|
||||||
// Prepend the provisioner extension. In the auth.Sign code we will
|
// Prepend the provisioner extension. In the auth.Sign code we will
|
||||||
// force the resulting certificate to only have one extension, the
|
// force the resulting certificate to only have one extension, the
|
||||||
|
@ -477,7 +480,7 @@ func createProvisionerExtension(typ int, name, credentialID string, keyValuePair
|
||||||
KeyValuePairs: keyValuePairs,
|
KeyValuePairs: keyValuePairs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
|
||||||
}
|
}
|
||||||
return pkix.Extension{
|
return pkix.Extension{
|
||||||
Id: stepOIDProvisioner,
|
Id: stepOIDProvisioner,
|
||||||
|
|
|
@ -77,12 +77,12 @@ func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
|
||||||
{
|
{
|
||||||
"fail/unrecognized-key-type",
|
"fail/unrecognized-key-type",
|
||||||
&x509.CertificateRequest{PublicKey: "foo"},
|
&x509.CertificateRequest{PublicKey: "foo"},
|
||||||
errors.New("unrecognized public key of type 'string' in CSR"),
|
errors.New("certificate request key of type 'string' is not supported"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fail/rsa/too-short",
|
"fail/rsa/too-short",
|
||||||
shortRSA,
|
shortRSA,
|
||||||
errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ok/rsa",
|
"ok/rsa",
|
||||||
|
@ -303,14 +303,14 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}},
|
csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}},
|
||||||
expectedSANs: []string{"dcow@fx.com"},
|
expectedSANs: []string{"dcow@fx.com"},
|
||||||
err: errors.New("certificate request does not contain the valid Email Addresses"),
|
err: errors.New("certificate request does not contain the valid email addresses"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/ipAddressesValidator": func() test {
|
"fail/ipAddressesValidator": func() test {
|
||||||
return test{
|
return test{
|
||||||
csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}},
|
csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}},
|
||||||
expectedSANs: []string{"127.0.0.1"},
|
expectedSANs: []string{"127.0.0.1"},
|
||||||
err: errors.New("IP Addresses claim failed"),
|
err: errors.New("certificate request does not contain the valid IP addresses"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/urisValidator": func() test {
|
"fail/urisValidator": func() test {
|
||||||
|
@ -321,7 +321,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}},
|
csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}},
|
||||||
expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"},
|
expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"},
|
||||||
err: errors.New("URIs claim failed"),
|
err: errors.New("certificate request does not contain the valid URIs"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
|
@ -512,7 +512,7 @@ func Test_forceCN_Option(t *testing.T) {
|
||||||
Subject: pkix.Name{},
|
Subject: pkix.Name{},
|
||||||
DNSNames: []string{},
|
DNSNames: []string{},
|
||||||
},
|
},
|
||||||
err: errors.New("Cannot force CN, DNSNames is empty"),
|
err: errors.New("cannot force common name, DNS names is empty"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,12 @@ type SignSSHOptions struct {
|
||||||
// Validate validates the given SignSSHOptions.
|
// Validate validates the given SignSSHOptions.
|
||||||
func (o SignSSHOptions) Validate() error {
|
func (o SignSSHOptions) Validate() error {
|
||||||
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
|
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
|
||||||
return errs.BadRequest("unknown certificate type '%s'", o.CertType)
|
return errs.BadRequest("certType '%s' is not valid", o.CertType)
|
||||||
|
}
|
||||||
|
for _, p := range o.Principals {
|
||||||
|
if p == "" {
|
||||||
|
return errs.BadRequest("principals cannot contain empty values")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -75,7 +80,7 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
case SSHHostCert:
|
case SSHHostCert:
|
||||||
cert.CertType = ssh.HostCert
|
cert.CertType = ssh.HostCert
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType)
|
return errs.BadRequest("ssh certificate has an unknown type '%s'", o.CertType)
|
||||||
}
|
}
|
||||||
|
|
||||||
cert.KeyId = o.KeyID
|
cert.KeyId = o.KeyID
|
||||||
|
@ -95,7 +100,7 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
|
||||||
cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
|
cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
|
||||||
}
|
}
|
||||||
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||||
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -104,16 +109,16 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
|
||||||
// ignores zero values.
|
// ignores zero values.
|
||||||
func (o SignSSHOptions) match(got SignSSHOptions) error {
|
func (o SignSSHOptions) match(got SignSSHOptions) error {
|
||||||
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
||||||
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
return errs.Forbidden("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
||||||
}
|
}
|
||||||
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
|
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
|
||||||
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
return errs.Forbidden("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
||||||
}
|
}
|
||||||
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
||||||
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
return errs.Forbidden("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
||||||
}
|
}
|
||||||
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
||||||
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
return errs.Forbidden("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -206,7 +211,7 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOpt
|
||||||
cert.Extensions["permit-user-rc"] = ""
|
cert.Extensions["permit-user-rc"] = ""
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return errors.New("ssh certificate type has not been set or is invalid")
|
return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -272,7 +277,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
|
||||||
|
|
||||||
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
|
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
|
||||||
if certValidAfter.After(m.NotAfter) {
|
if certValidAfter.After(m.NotAfter) {
|
||||||
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
|
return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
|
||||||
m.NotAfter, certValidAfter)
|
m.NotAfter, certValidAfter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,7 +290,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
|
||||||
} else {
|
} else {
|
||||||
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
|
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
|
||||||
if m.NotAfter.Before(certValidBefore) {
|
if m.NotAfter.Before(certValidBefore) {
|
||||||
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
|
return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
|
||||||
m.NotAfter, certValidBefore)
|
m.NotAfter, certValidBefore)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -319,11 +324,11 @@ type sshCertOptionsRequireValidator struct {
|
||||||
func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error {
|
func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error {
|
||||||
switch {
|
switch {
|
||||||
case v.CertType && got.CertType == "":
|
case v.CertType && got.CertType == "":
|
||||||
return errors.New("ssh certificate certType cannot be empty")
|
return errs.BadRequest("ssh certificate certType cannot be empty")
|
||||||
case v.KeyID && got.KeyID == "":
|
case v.KeyID && got.KeyID == "":
|
||||||
return errors.New("ssh certificate keyID cannot be empty")
|
return errs.BadRequest("ssh certificate keyID cannot be empty")
|
||||||
case v.Principals && len(got.Principals) == 0:
|
case v.Principals && len(got.Principals) == 0:
|
||||||
return errors.New("ssh certificate principals cannot be empty")
|
return errs.BadRequest("ssh certificate principals cannot be empty")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -354,7 +359,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
|
||||||
case 0:
|
case 0:
|
||||||
return errs.BadRequest("ssh certificate type has not been set")
|
return errs.BadRequest("ssh certificate type has not been set")
|
||||||
default:
|
default:
|
||||||
return errs.BadRequest("unknown ssh certificate type %d", cert.CertType)
|
return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// To not take into account the backdate, time.Now() will be used to
|
// To not take into account the backdate, time.Now() will be used to
|
||||||
|
@ -363,9 +368,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case dur < min:
|
case dur < min:
|
||||||
return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
|
return errs.Forbidden("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
|
||||||
case dur > max+opts.Backdate:
|
case dur > max+opts.Backdate:
|
||||||
return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
|
return errs.Forbidden("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -381,25 +386,25 @@ type sshCertDefaultValidator struct{}
|
||||||
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
|
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
|
||||||
switch {
|
switch {
|
||||||
case len(cert.Nonce) == 0:
|
case len(cert.Nonce) == 0:
|
||||||
return errors.New("ssh certificate nonce cannot be empty")
|
return errs.Forbidden("ssh certificate nonce cannot be empty")
|
||||||
case cert.Key == nil:
|
case cert.Key == nil:
|
||||||
return errors.New("ssh certificate key cannot be nil")
|
return errs.Forbidden("ssh certificate key cannot be nil")
|
||||||
case cert.Serial == 0:
|
case cert.Serial == 0:
|
||||||
return errors.New("ssh certificate serial cannot be 0")
|
return errs.Forbidden("ssh certificate serial cannot be 0")
|
||||||
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||||
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType)
|
return errs.Forbidden("ssh certificate has an unknown type '%d'", cert.CertType)
|
||||||
case cert.KeyId == "":
|
case cert.KeyId == "":
|
||||||
return errors.New("ssh certificate key id cannot be empty")
|
return errs.Forbidden("ssh certificate key id cannot be empty")
|
||||||
case cert.ValidAfter == 0:
|
case cert.ValidAfter == 0:
|
||||||
return errors.New("ssh certificate validAfter cannot be 0")
|
return errs.Forbidden("ssh certificate validAfter cannot be 0")
|
||||||
case cert.ValidBefore < uint64(now().Unix()):
|
case cert.ValidBefore < uint64(now().Unix()):
|
||||||
return errors.New("ssh certificate validBefore cannot be in the past")
|
return errs.Forbidden("ssh certificate validBefore cannot be in the past")
|
||||||
case cert.ValidBefore < cert.ValidAfter:
|
case cert.ValidBefore < cert.ValidAfter:
|
||||||
return errors.New("ssh certificate validBefore cannot be before validAfter")
|
return errs.Forbidden("ssh certificate validBefore cannot be before validAfter")
|
||||||
case cert.SignatureKey == nil:
|
case cert.SignatureKey == nil:
|
||||||
return errors.New("ssh certificate signature key cannot be nil")
|
return errs.Forbidden("ssh certificate signature key cannot be nil")
|
||||||
case cert.Signature == nil:
|
case cert.Signature == nil:
|
||||||
return errors.New("ssh certificate signature cannot be nil")
|
return errs.Forbidden("ssh certificate signature cannot be nil")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -409,27 +414,31 @@ func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions)
|
||||||
type sshDefaultPublicKeyValidator struct{}
|
type sshDefaultPublicKeyValidator struct{}
|
||||||
|
|
||||||
// Valid checks that certificate request common name matches the one configured.
|
// Valid checks that certificate request common name matches the one configured.
|
||||||
|
//
|
||||||
|
// TODO: this is the only validator that checks the key type. We should execute
|
||||||
|
// this before the signing. We should add a new validations interface or extend
|
||||||
|
// SSHCertOptionsValidator with the key.
|
||||||
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
|
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
|
||||||
if cert.Key == nil {
|
if cert.Key == nil {
|
||||||
return errors.New("ssh certificate key cannot be nil")
|
return errs.BadRequest("ssh certificate key cannot be nil")
|
||||||
}
|
}
|
||||||
switch cert.Key.Type() {
|
switch cert.Key.Type() {
|
||||||
case ssh.KeyAlgoRSA:
|
case ssh.KeyAlgoRSA:
|
||||||
_, in, ok := sshParseString(cert.Key.Marshal())
|
_, in, ok := sshParseString(cert.Key.Marshal())
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("ssh certificate key is invalid")
|
return errs.BadRequest("ssh certificate key is invalid")
|
||||||
}
|
}
|
||||||
key, err := sshParseRSAPublicKey(in)
|
key, err := sshParseRSAPublicKey(in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errs.BadRequestErr(err, "error parsing public key")
|
||||||
}
|
}
|
||||||
if key.Size() < keyutil.MinRSAKeyBytes {
|
if key.Size() < keyutil.MinRSAKeyBytes {
|
||||||
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)",
|
return errs.Forbidden("ssh certificate key must be at least %d bits (%d bytes)",
|
||||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
|
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case ssh.KeyAlgoDSA:
|
case ssh.KeyAlgoDSA:
|
||||||
return errors.New("ssh certificate key algorithm (DSA) is not supported")
|
return errs.BadRequest("ssh certificate key algorithm (DSA) is not supported")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,14 +49,14 @@ func TestSSHOptions_Modify(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "foo"},
|
so: SignSSHOptions{CertType: "foo"},
|
||||||
cert: new(ssh.Certificate),
|
cert: new(ssh.Certificate),
|
||||||
err: errors.Errorf("ssh certificate has an unknown type - foo"),
|
err: errors.Errorf("ssh certificate has an unknown type 'foo'"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/validAfter-greater-validBefore": func() test {
|
"fail/validAfter-greater-validBefore": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "user"},
|
so: SignSSHOptions{CertType: "user"},
|
||||||
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
|
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
|
||||||
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
|
err: errors.Errorf("ssh certificate validAfter cannot be greater than validBefore"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/user-cert": func() test {
|
"ok/user-cert": func() test {
|
||||||
|
@ -136,14 +136,14 @@ func TestSSHOptions_Match(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
||||||
cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
||||||
err: errors.Errorf("ssh certificate valid after does not match"),
|
err: errors.Errorf("ssh certificate validAfter does not match"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/validBefore": func() test {
|
"fail/validBefore": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
||||||
cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
||||||
err: errors.Errorf("ssh certificate valid before does not match"),
|
err: errors.Errorf("ssh certificate validBefore does not match"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/original-empty": func() test {
|
"ok/original-empty": func() test {
|
||||||
|
@ -394,7 +394,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
modifier: sshDefaultExtensionModifier{},
|
modifier: sshDefaultExtensionModifier{},
|
||||||
cert: cert,
|
cert: cert,
|
||||||
err: errors.New("ssh certificate type has not been set or is invalid"),
|
err: errors.New("ssh certificate has an unknown type '3'"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/host": func() test {
|
"ok/host": func() test {
|
||||||
|
@ -518,7 +518,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||||
"fail/unexpected-cert-type",
|
"fail/unexpected-cert-type",
|
||||||
// UserCert = 1, HostCert = 2
|
// UserCert = 1, HostCert = 2
|
||||||
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
|
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
|
||||||
errors.New("ssh certificate has an unknown type: 3"),
|
errors.New("ssh certificate has an unknown type '3'"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fail/empty-cert-key-id",
|
"fail/empty-cert-key-id",
|
||||||
|
@ -725,7 +725,7 @@ func Test_sshCertValidityValidator(t *testing.T) {
|
||||||
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
|
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
|
||||||
},
|
},
|
||||||
SignSSHOptions{},
|
SignSSHOptions{},
|
||||||
errors.New("unknown ssh certificate type 3"),
|
errors.New("ssh certificate has an unknown type '3'"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fail/duration<min",
|
"fail/duration<min",
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
|
@ -174,7 +173,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// validate the given SSHOptions
|
// validate the given SSHOptions
|
||||||
case provisioner.SSHCertOptionsValidator:
|
case provisioner.SSHCertOptionsValidator:
|
||||||
if err := o.Valid(opts); err != nil {
|
if err := o.Valid(opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
return nil, errs.BadRequestErr(err, "error validating ssh certificate options")
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -214,7 +213,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// Use provisioner modifiers.
|
// Use provisioner modifiers.
|
||||||
for _, m := range mods {
|
for _, m := range mods {
|
||||||
if err := m.Modify(certTpl, opts); err != nil {
|
if err := m.Modify(certTpl, opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
return nil, errs.ForbiddenErr(err, "error creating ssh certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,7 +243,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// User provisioners validators.
|
// User provisioners validators.
|
||||||
for _, v := range validators {
|
for _, v := range validators {
|
||||||
if err := v.Valid(cert, opts); err != nil {
|
if err := v.Valid(cert, opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
return nil, errs.ForbiddenErr(err, "error validating ssh certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -382,7 +381,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
|
||||||
// Apply validators from provisioner.
|
// Apply validators from provisioner.
|
||||||
for _, v := range validators {
|
for _, v := range validators {
|
||||||
if err := v.Valid(cert, provisioner.SignSSHOptions{Backdate: backdate}); err != nil {
|
if err := v.Valid(cert, provisioner.SignSSHOptions{Backdate: backdate}); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
|
return nil, errs.ForbiddenErr(err, "error validating ssh certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,12 +406,12 @@ func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error {
|
||||||
// the given certificate.
|
// the given certificate.
|
||||||
func IsValidForAddUser(cert *ssh.Certificate) error {
|
func IsValidForAddUser(cert *ssh.Certificate) error {
|
||||||
if cert.CertType != ssh.UserCert {
|
if cert.CertType != ssh.UserCert {
|
||||||
return errors.New("certificate is not a user certificate")
|
return errs.Forbidden("certificate is not a user certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch len(cert.ValidPrincipals) {
|
switch len(cert.ValidPrincipals) {
|
||||||
case 0:
|
case 0:
|
||||||
return errors.New("certificate does not have any principals")
|
return errs.Forbidden("certificate does not have any principals")
|
||||||
case 1:
|
case 1:
|
||||||
return nil
|
return nil
|
||||||
case 2:
|
case 2:
|
||||||
|
@ -421,9 +420,9 @@ func IsValidForAddUser(cert *ssh.Certificate) error {
|
||||||
if strings.Index(cert.ValidPrincipals[1], "@") > 0 {
|
if strings.Index(cert.ValidPrincipals[1], "@") > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return errors.New("certificate does not have only one principal")
|
return errs.Forbidden("certificate does not have only one principal")
|
||||||
default:
|
default:
|
||||||
return errors.New("certificate does not have only one principal")
|
return errs.Forbidden("certificate does not have only one principal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,7 +432,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
|
||||||
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
||||||
}
|
}
|
||||||
if err := IsValidForAddUser(subject); err != nil {
|
if err := IsValidForAddUser(subject); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSHAddUser")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, err := randutil.ASCII(32)
|
nonce, err := randutil.ASCII(32)
|
||||||
|
|
|
@ -94,7 +94,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
||||||
// Validate the given certificate request.
|
// Validate the given certificate request.
|
||||||
case provisioner.CertificateRequestValidator:
|
case provisioner.CertificateRequestValidator:
|
||||||
if err := k.Valid(csr); err != nil {
|
if err := k.Valid(csr); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
return nil, errs.ApplyOptions(
|
||||||
|
errs.ForbiddenErr(err, "error validating certificate"),
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validates the unsigned certificate template.
|
// Validates the unsigned certificate template.
|
||||||
|
@ -131,26 +134,38 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
||||||
|
|
||||||
// Set default subject
|
// Set default subject
|
||||||
if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil {
|
if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
return nil, errs.ApplyOptions(
|
||||||
|
errs.ForbiddenErr(err, "error creating certificate"),
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range certModifiers {
|
for _, m := range certModifiers {
|
||||||
if err := m.Modify(leaf, signOpts); err != nil {
|
if err := m.Modify(leaf, signOpts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
return nil, errs.ApplyOptions(
|
||||||
|
errs.ForbiddenErr(err, "error creating certificate"),
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate validation.
|
// Certificate validation.
|
||||||
for _, v := range certValidators {
|
for _, v := range certValidators {
|
||||||
if err := v.Valid(leaf, signOpts); err != nil {
|
if err := v.Valid(leaf, signOpts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
return nil, errs.ApplyOptions(
|
||||||
|
errs.ForbiddenErr(err, "error validating certificate"),
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate modifiers after validation
|
// Certificate modifiers after validation
|
||||||
for _, m := range certEnforcers {
|
for _, m := range certEnforcers {
|
||||||
if err := m.Enforce(leaf); err != nil {
|
if err := m.Enforce(leaf); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
return nil, errs.ApplyOptions(
|
||||||
|
errs.ForbiddenErr(err, "error creating certificate"),
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,6 +343,7 @@ type RevokeOptions struct {
|
||||||
ReasonCode int
|
ReasonCode int
|
||||||
PassiveOnly bool
|
PassiveOnly bool
|
||||||
MTLS bool
|
MTLS bool
|
||||||
|
ACME bool
|
||||||
Crt *x509.Certificate
|
Crt *x509.Certificate
|
||||||
OTT string
|
OTT string
|
||||||
}
|
}
|
||||||
|
@ -345,9 +361,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
errs.WithKeyVal("reason", revokeOpts.Reason),
|
errs.WithKeyVal("reason", revokeOpts.Reason),
|
||||||
errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
|
errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
|
||||||
errs.WithKeyVal("MTLS", revokeOpts.MTLS),
|
errs.WithKeyVal("MTLS", revokeOpts.MTLS),
|
||||||
|
errs.WithKeyVal("ACME", revokeOpts.ACME),
|
||||||
errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()),
|
errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()),
|
||||||
}
|
}
|
||||||
if revokeOpts.MTLS {
|
if revokeOpts.MTLS || revokeOpts.ACME {
|
||||||
opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
|
opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
|
||||||
} else {
|
} else {
|
||||||
opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
|
opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
|
||||||
|
@ -358,6 +375,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
ReasonCode: revokeOpts.ReasonCode,
|
ReasonCode: revokeOpts.ReasonCode,
|
||||||
Reason: revokeOpts.Reason,
|
Reason: revokeOpts.Reason,
|
||||||
MTLS: revokeOpts.MTLS,
|
MTLS: revokeOpts.MTLS,
|
||||||
|
ACME: revokeOpts.ACME,
|
||||||
RevokedAt: time.Now().UTC(),
|
RevokedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,8 +383,8 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
p provisioner.Interface
|
p provisioner.Interface
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
// If not mTLS then get the TokenID of the token.
|
// If not mTLS nor ACME, then get the TokenID of the token.
|
||||||
if !revokeOpts.MTLS {
|
if !(revokeOpts.MTLS || revokeOpts.ACME) {
|
||||||
token, err := jose.ParseSigned(revokeOpts.OTT)
|
token, err := jose.ParseSigned(revokeOpts.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusUnauthorized, err,
|
return errs.Wrap(http.StatusUnauthorized, err,
|
||||||
|
|
|
@ -281,8 +281,8 @@ func TestAuthority_Sign(t *testing.T) {
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: extraOpts,
|
extraOpts: extraOpts,
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: errors.New("authority.Sign: default ASN1DN template cannot be nil"),
|
err: errors.New("default ASN1DN template cannot be nil"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusForbidden,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail create cert": func(t *testing.T) *signTest {
|
"fail create cert": func(t *testing.T) *signTest {
|
||||||
|
@ -309,8 +309,8 @@ func TestAuthority_Sign(t *testing.T) {
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: extraOpts,
|
extraOpts: extraOpts,
|
||||||
signOpts: _signOpts,
|
signOpts: _signOpts,
|
||||||
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
|
err: errors.New("requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
|
||||||
code: http.StatusBadRequest,
|
code: http.StatusForbidden,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
|
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
|
||||||
|
@ -322,8 +322,8 @@ func TestAuthority_Sign(t *testing.T) {
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: extraOpts,
|
extraOpts: extraOpts,
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
err: errors.New("certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusForbidden,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail rsa key too short": func(t *testing.T) *signTest {
|
"fail rsa key too short": func(t *testing.T) *signTest {
|
||||||
|
@ -348,8 +348,8 @@ ZYtQ9Ot36qc=
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: extraOpts,
|
extraOpts: extraOpts,
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
err: errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusForbidden,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail store cert in db": func(t *testing.T) *signTest {
|
"fail store cert in db": func(t *testing.T) *signTest {
|
||||||
|
@ -1267,6 +1267,23 @@ func TestAuthority_Revoke(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"ok/ACME": func() test {
|
||||||
|
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
|
||||||
|
|
||||||
|
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
auth: _a,
|
||||||
|
opts: &RevokeOptions{
|
||||||
|
Crt: crt,
|
||||||
|
Serial: "102012593071130646873265215610956555026",
|
||||||
|
ReasonCode: reasonCode,
|
||||||
|
Reason: reason,
|
||||||
|
ACME: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for name, f := range tests {
|
for name, f := range tests {
|
||||||
tc := f()
|
tc := f()
|
||||||
|
|
2
ca/ca.go
2
ca/ca.go
|
@ -442,7 +442,7 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldMountSCEPEndpoints returns if the CA should be
|
// shouldServeSCEPEndpoints returns if the CA should be
|
||||||
// configured with endpoints for SCEP. This is assumed to be
|
// configured with endpoints for SCEP. This is assumed to be
|
||||||
// true if a SCEPService exists, which is true in case a
|
// true if a SCEPService exists, which is true in case a
|
||||||
// SCEP provisioner was configured.
|
// SCEP provisioner was configured.
|
||||||
|
|
|
@ -200,8 +200,8 @@ ZEp7knvU2psWRw==
|
||||||
return &signTest{
|
return &signTest{
|
||||||
ca: ca,
|
ca: ca,
|
||||||
body: string(body),
|
body: string(body),
|
||||||
status: http.StatusUnauthorized,
|
status: http.StatusForbidden,
|
||||||
errMsg: errs.UnauthorizedDefaultMsg,
|
errMsg: errs.ForbiddenPrefix,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *signTest {
|
"ok": func(t *testing.T) *signTest {
|
||||||
|
|
1
db/db.go
1
db/db.go
|
@ -104,6 +104,7 @@ type RevokedCertificateInfo struct {
|
||||||
RevokedAt time.Time
|
RevokedAt time.Time
|
||||||
TokenID string
|
TokenID string
|
||||||
MTLS bool
|
MTLS bool
|
||||||
|
ACME bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsRevoked returns whether or not a certificate with the given identifier
|
// IsRevoked returns whether or not a certificate with the given identifier
|
||||||
|
|
BIN
docs/images/star.gif
Normal file
BIN
docs/images/star.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 89 KiB |
|
@ -169,7 +169,8 @@ func StatusCodeError(code int, e error, opts ...Option) error {
|
||||||
case http.StatusUnauthorized:
|
case http.StatusUnauthorized:
|
||||||
return UnauthorizedErr(e, opts...)
|
return UnauthorizedErr(e, opts...)
|
||||||
case http.StatusForbidden:
|
case http.StatusForbidden:
|
||||||
return ForbiddenErr(e, opts...)
|
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
|
||||||
|
return NewErr(http.StatusForbidden, e, opts...)
|
||||||
case http.StatusInternalServerError:
|
case http.StatusInternalServerError:
|
||||||
return InternalServerErr(e, opts...)
|
return InternalServerErr(e, opts...)
|
||||||
case http.StatusNotImplemented:
|
case http.StatusNotImplemented:
|
||||||
|
@ -199,12 +200,18 @@ var (
|
||||||
// BadRequestPrefix is the prefix added to the bad request messages that are
|
// BadRequestPrefix is the prefix added to the bad request messages that are
|
||||||
// directly sent to the cli.
|
// directly sent to the cli.
|
||||||
BadRequestPrefix = "The request could not be completed: "
|
BadRequestPrefix = "The request could not be completed: "
|
||||||
|
|
||||||
|
// ForbiddenPrefix is the prefix added to the forbidden messates that are
|
||||||
|
// sent to the cli.
|
||||||
|
ForbiddenPrefix = "The request was forbidden by the certificate authority: "
|
||||||
)
|
)
|
||||||
|
|
||||||
func formatMessage(status int, msg string) string {
|
func formatMessage(status int, msg string) string {
|
||||||
switch status {
|
switch status {
|
||||||
case http.StatusBadRequest:
|
case http.StatusBadRequest:
|
||||||
return BadRequestPrefix + msg + "."
|
return BadRequestPrefix + msg + "."
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return ForbiddenPrefix + msg + "."
|
||||||
default:
|
default:
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
@ -356,14 +363,12 @@ func UnauthorizedErr(err error, opts ...Option) error {
|
||||||
|
|
||||||
// Forbidden creates a 403 error with the given format and arguments.
|
// Forbidden creates a 403 error with the given format and arguments.
|
||||||
func Forbidden(format string, args ...interface{}) error {
|
func Forbidden(format string, args ...interface{}) error {
|
||||||
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
|
return New(http.StatusForbidden, format, args...)
|
||||||
return Errorf(http.StatusForbidden, format, args...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForbiddenErr returns an 403 error with the given error.
|
// ForbiddenErr returns an 403 error with the given error.
|
||||||
func ForbiddenErr(err error, opts ...Option) error {
|
func ForbiddenErr(err error, format string, args ...interface{}) error {
|
||||||
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
|
return NewError(http.StatusForbidden, err, format, args...)
|
||||||
return NewErr(http.StatusForbidden, err, opts...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotFound creates a 404 error with the given format and arguments.
|
// NotFound creates a 404 error with the given format and arguments.
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -18,6 +18,7 @@ require (
|
||||||
github.com/go-kit/kit v0.10.0 // indirect
|
github.com/go-kit/kit v0.10.0 // indirect
|
||||||
github.com/go-piv/piv-go v1.7.0
|
github.com/go-piv/piv-go v1.7.0
|
||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
|
github.com/google/go-cmp v0.5.6
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/googleapis/gax-go/v2 v2.0.5
|
github.com/googleapis/gax-go/v2 v2.0.5
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||||
|
|
Loading…
Reference in a new issue