Merge pull request #161 from smallstep/unittests

Introduce generalized statusCoder errors and loads of ssh unit tests.
This commit is contained in:
Max 2020-01-24 16:16:00 -08:00 committed by GitHub
commit f3f8ee4207
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
88 changed files with 5620 additions and 2544 deletions

View file

@ -63,6 +63,7 @@ issues:
- declaration of "err" shadows declaration at line
- should have a package comment, unless it's in another file for this package
- error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration
service:

View file

@ -5,7 +5,6 @@ import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
@ -209,14 +208,6 @@ type RootResponse struct {
RootPEM Certificate `json:"ca"`
}
// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
CsrPEM CertificateRequest `json:"csr"`
OTT string `json:"ott"`
NotAfter TimeDuration `json:"notAfter"`
NotBefore TimeDuration `json:"notBefore"`
}
// ProvisionersResponse is the response object that returns the list of
// provisioners.
type ProvisionersResponse struct {
@ -230,31 +221,6 @@ type ProvisionerKeyResponse struct {
Key string `json:"key"`
}
// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest(errors.New("missing csr"))
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
}
if s.OTT == "" {
return errs.BadRequest(errors.New("missing ott"))
}
return nil
}
// SignResponse is the response object of the certificate signature request.
type SignResponse struct {
ServerPEM Certificate `json:"crt"`
CaPEM Certificate `json:"ca"`
CertChainPEM []Certificate `json:"certChain"`
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
TLS *tls.ConnectionState `json:"-"`
}
// RootsResponse is the response object of the roots request.
type RootsResponse struct {
Certificates []Certificate `json:"crts"`
@ -329,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return
}
@ -344,91 +310,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
return certChainPEM
}
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
opts := provisioner.Options{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}
// Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r)
if err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &ProvisionersResponse{
@ -442,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
if err != nil {
WriteError(w, errs.NotFound(err))
WriteError(w, errs.NotFoundErr(err))
return
}
JSON(w, &ProvisionerKeyResponse{key})
@ -452,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
@ -470,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -28,6 +28,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates"
@ -914,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
}
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
@ -934,13 +935,13 @@ func Test_caHandler_Renew(t *testing.T) {
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
t.Errorf("caHandler.Renew unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
@ -1009,8 +1010,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err)
}
expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`)
expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`)
expectedError400 := errs.BadRequest("force")
expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force")
expectedError500Bytes, err := json.Marshal(expectedError500)
assert.FatalError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &caHandler{
@ -1035,12 +1040,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
} else {
switch tt.statusCode {
case 400:
if !bytes.Equal(bytes.TrimSpace(body), expectedError400) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400)
if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
}
case 500:
if !bytes.Equal(bytes.TrimSpace(body), expectedError500) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500)
if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
}
default:
t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)
@ -1077,7 +1082,9 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
}
expected := []byte(`{"key":"` + privKey + `"}`)
expectedError := []byte(`{"status":404,"message":"Not Found"}`)
expectedError404 := errs.NotFound("force")
expectedError404Bytes, err := json.Marshal(expectedError404)
assert.FatalError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -1101,8 +1108,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
}
} else {
if !bytes.Equal(bytes.TrimSpace(body), expectedError) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError)
if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes)
}
}
})

35
api/renew.go Normal file
View file

@ -0,0 +1,35 @@
package api
import (
"net/http"
"github.com/smallstep/certificates/errs"
)
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate"))
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}

View file

@ -4,7 +4,6 @@ import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -30,13 +29,13 @@ type RevokeRequest struct {
// or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial"))
return errs.BadRequest("missing serial")
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest("reasonCode out of bounds")
}
if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented("non-passive revocation not implemented")
}
return
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate")))
WriteError(w, errs.BadRequest("missing ott or peer certificate"))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
return
}
// TODO: should probably be checking if the certificate was revoked here.
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
return nil, nil
},
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
return errs.InternalServerError(errors.New("force"))
return errs.InternalServer("force")
},
},
}

89
api/sign.go Normal file
View file

@ -0,0 +1,89 @@
package api
import (
"crypto/tls"
"net/http"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/tlsutil"
)
// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
CsrPEM CertificateRequest `json:"csr"`
OTT string `json:"ott"`
NotAfter TimeDuration `json:"notAfter"`
NotBefore TimeDuration `json:"notBefore"`
}
// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
}
if s.OTT == "" {
return errs.BadRequest("missing ott")
}
return nil
}
// SignResponse is the response object of the certificate signature request.
type SignResponse struct {
ServerPEM Certificate `json:"crt"`
CaPEM Certificate `json:"ca"`
CertChainPEM []Certificate `json:"certChain"`
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
TLS *tls.ConnectionState `json:"-"`
}
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
opts := provisioner.Options{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}

View file

@ -249,19 +249,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return
}
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
return
}
}
@ -282,16 +282,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ValidAfter: body.ValidAfter,
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(cr, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identityCertificate = certChainToPEM(certChain)
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots()
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound("no keys found"))
return
}
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation()
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound("no keys found"))
return
}
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert:
config.HostTemplates = ts
default:
WriteError(w, errs.InternalServerError(errors.New("it should hot get here")))
WriteError(w, errs.InternalServer("it should hot get here"))
return
}
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHCheckPrincipalResponse{
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(cert)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHGetHostsResponse{
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}

View file

@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
}
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identity, err := h.renewIdentityCertificate(r)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -36,36 +36,36 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
}
newCert, err := h.Authority.RenewSSH(oldCert)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identity, err := h.renewIdentityCertificate(r)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -4,7 +4,6 @@ import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
// or an error if something is wrong.
func (r *SSHRevokeRequest) Validate() (err error) {
if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial"))
return errs.BadRequest("missing serial")
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest("reasonCode out of bounds")
}
if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented("non-passive revocation not implemented")
}
if len(r.OTT) == 0 {
return errs.BadRequest(errors.New("missing ott"))
return errs.BadRequest("missing ott")
}
return
}
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
@ -66,18 +65,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive,
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeSSHMethod)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod)
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -6,7 +6,6 @@ import (
"log"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequest(errors.Wrap(err, "error decoding json"))
return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
}
return nil
}

View file

@ -13,12 +13,13 @@ import (
stepJOSE "github.com/smallstep/cli/jose"
)
func testAuthority(t *testing.T) *Authority {
func testAuthority(t *testing.T, opts ...Option) *Authority {
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err)
disableRenewal := true
enableSSHCA := true
p := provisioner.List{
&provisioner.JWK{
Name: "Max",
@ -29,6 +30,9 @@ func testAuthority(t *testing.T) *Authority {
Name: "step-cli",
Type: "JWK",
Key: clijwk,
Claims: &provisioner.Claims{
EnableSSHCA: &enableSSHCA,
},
},
&provisioner.JWK{
Name: "dev",
@ -46,19 +50,30 @@ func testAuthority(t *testing.T) *Authority {
DisableRenewal: &disableRenewal,
},
},
&provisioner.SSHPOP{
Name: "sshpop",
Type: "SSHPOP",
Claims: &provisioner.Claims{
EnableSSHCA: &enableSSHCA,
},
},
}
c := &Config{
Address: "127.0.0.1:443",
Root: []string{"testdata/certs/root_ca.crt"},
IntermediateCert: "testdata/certs/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.ca.smallstep.com"},
Password: "pass",
SSH: &SSHConfig{
HostKey: "testdata/secrets/ssh_host_ca_key",
UserKey: "testdata/secrets/ssh_user_ca_key",
},
DNSNames: []string{"example.com"},
Password: "pass",
AuthorityConfig: &AuthConfig{
Provisioners: p,
},
}
a, err := New(c)
a, err := New(c, opts...)
assert.FatalError(t, err)
return a
}

View file

@ -6,9 +6,10 @@ import (
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
// Claims extends jose.Claims with step attributes.
@ -36,22 +37,19 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
// authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be
// used once).
func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner.Interface, error) {
var errContext = map[string]interface{}{"ott": ott}
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
// Validate payload
token, err := jose.ParseSigned(ott)
tok, err := jose.ParseSigned(token)
if err != nil {
return nil, &apiError{errors.Wrapf(err, "authorizeToken: error parsing token"),
http.StatusUnauthorized, errContext}
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token")
}
// Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner.
var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeToken"), http.StatusUnauthorized, errContext}
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
}
// TODO: use new persistence layer abstraction.
@ -59,29 +57,27 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
// This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
http.StatusUnauthorized, errContext}
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
}
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, &apiError{
errors.Errorf("authorizeToken: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
http.StatusUnauthorized, errContext}
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// Store the token to protect against reuse unless it's skipped.
if !SkipTokenReuseFromContext(ctx) {
if reuseKey, err := p.GetTokenID(ott); err == nil {
ok, err := a.db.UseToken(reuseKey, ott)
if reuseKey, err := p.GetTokenID(token); err == nil {
ok, err := a.db.UseToken(reuseKey, token)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeToken: failed when attempting to store token"),
http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.authorizeToken: failed when attempting to store token")
}
if !ok {
return nil, &apiError{errors.Errorf("authorizeToken: token already used"), http.StatusUnauthorized, errContext}
return nil, errs.Unauthorized("authority.authorizeToken: token already used")
}
}
}
@ -89,125 +85,158 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
return p, nil
}
// Authorize grabs the method from the context and authorizes a signature
// request by validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
// Authorize grabs the method from the context and authorizes the request by
// validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
var opts = []interface{}{errs.WithKeyVal("token", token)}
switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod:
return a.authorizeSign(ctx, ott)
signOpts, err := a.authorizeSign(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.RevokeMethod:
return nil, a.authorizeRevoke(ctx, ott)
case provisioner.SignSSHMethod:
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHSignMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
return a.authorizeSSHSign(ctx, ott)
case provisioner.RenewSSHMethod:
_, err := a.authorizeSSHSign(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.SSHRenewMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
if _, err := a.authorizeSSHRenew(ctx, ott); err != nil {
return nil, err
}
return nil, nil
case provisioner.RevokeSSHMethod:
return nil, a.authorizeSSHRevoke(ctx, ott)
case provisioner.RekeySSHMethod:
_, err := a.authorizeSSHRenew(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.SSHRevokeMethod:
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHRekeyMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
_, opts, err := a.authorizeSSHRekey(ctx, ott)
if err != nil {
return nil, err
}
return opts, nil
_, signOpts, err := a.authorizeSSHRekey(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
default:
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
}
}
// authorizeSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
p, err := a.authorizeToken(ctx, ott)
// authorizeSign loads the provisioner from the token and calls the provisioner
// AuthorizeSign method. Returns a list of methods to apply to the signing flow.
func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
}
opts, err := p.AuthorizeSign(ctx, ott)
signOpts, err := p.AuthorizeSign(ctx, token)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
}
return opts, nil
return signOpts, nil
}
// AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request.
// a token that must be sent w/ the request.
//
// NOTE: This method is deprecated and should not be used. We make it available
// in the short term os as not to break existing clients.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, ott)
return a.Authorize(ctx, token)
}
// authorizeRevoke authorizes a revocation request by validating and authenticating
// the RevokeOptions POSTed with the request.
// Returns a tuple of the provisioner ID and error, if one occurred.
// authorizeRevoke locates the provisioner used to generate the authenticating
// token and then performs the token validation flow.
func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext}
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
}
if err = p.AuthorizeRevoke(ctx, token); err != nil {
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext}
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
}
return nil
}
// authorizeRenewl tries to locate the step provisioner extension, and checks
// authorizeRenew locates the provisioner (using the provisioner extension in the cert), and checks
// if for the configured provisioner, the renewal is enabled or not. If the
// extra extension cannot be found, authorize the renewal by default.
//
// TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(crt *x509.Certificate) error {
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(crt.SerialNumber.String())
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
if err != nil {
return &apiError{
err: errors.Wrap(err, "renew"),
code: http.StatusInternalServerError,
context: errContext,
}
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
if isRevoked {
return &apiError{
err: errors.New("renew: certificate has been revoked"),
code: http.StatusUnauthorized,
context: errContext,
}
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
}
p, ok := a.provisioners.LoadByCertificate(crt)
p, ok := a.provisioners.LoadByCertificate(cert)
if !ok {
return &apiError{
err: errors.New("renew: provisioner not found"),
code: http.StatusUnauthorized,
context: errContext,
}
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
if err := p.AuthorizeRenew(context.Background(), crt); err != nil {
return &apiError{
err: errors.Wrap(err, "renew"),
code: http.StatusUnauthorized,
context: errContext,
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
return nil
}
// authorizeSSHSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
}
signOpts, err := p.AuthorizeSSHSign(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
}
return signOpts, nil
}
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
}
cert, err := p.AuthorizeSSHRenew(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
}
return cert, nil
}
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
}
cert, signOpts, err := p.AuthorizeSSHRekey(ctx, token)
if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
}
return cert, signOpts, nil
}
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
}
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
}
return nil
}

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
package authority
import (
"fmt"
"testing"
"github.com/pkg/errors"
@ -9,7 +10,6 @@ import (
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
)
func TestConfigValidate(t *testing.T) {
@ -255,28 +255,6 @@ func TestAuthConfigValidate(t *testing.T) {
err: errors.New("authority cannot be undefined"),
}
},
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: provisioner.List{
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
},
},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: p,
Claims: &provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: -1},
},
},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
}
},
"ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{},
@ -311,7 +289,7 @@ func TestAuthConfigValidate(t *testing.T) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err, fmt.Sprintf("expected error: %s, but got <nil>", tc.err)) {
assert.Equals(t, *tc.ac.Template, tc.asn1dn)
}
}

View file

@ -1,96 +0,0 @@
package authority
import (
"crypto/x509"
"github.com/smallstep/certificates/db"
"golang.org/x/crypto/ssh"
)
type MockAuthDB struct {
err error
ret1 interface{}
isRevoked func(string) (bool, error)
isSSHRevoked func(string) (bool, error)
revoke func(rci *db.RevokedCertificateInfo) error
revokeSSH func(rci *db.RevokedCertificateInfo) error
storeCertificate func(crt *x509.Certificate) error
useToken func(id, tok string) (bool, error)
isSSHHost func(principal string) (bool, error)
storeSSHCertificate func(crt *ssh.Certificate) error
getSSHHostPrincipals func() ([]string, error)
shutdown func() error
}
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
if m.isRevoked != nil {
return m.isRevoked(sn)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
if m.isSSHRevoked != nil {
return m.isSSHRevoked(sn)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
if m.useToken != nil {
return m.useToken(id, tok)
}
if m.ret1 == nil {
return false, m.err
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error {
if m.revoke != nil {
return m.revoke(rci)
}
return m.err
}
func (m *MockAuthDB) RevokeSSH(rci *db.RevokedCertificateInfo) error {
if m.revokeSSH != nil {
return m.revokeSSH(rci)
}
return m.err
}
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.storeCertificate != nil {
return m.storeCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
if m.isSSHHost != nil {
return m.isSSHHost(principal)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
if m.storeSSHCertificate != nil {
return m.storeSSHCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
if m.getSSHHostPrincipals != nil {
return m.getSSHHostPrincipals()
}
return m.ret1.([]string), m.err
}
func (m *MockAuthDB) Shutdown() error {
if m.shutdown != nil {
return m.shutdown()
}
return m.err
}

View file

@ -1,67 +0,0 @@
package authority
import (
"encoding/json"
"fmt"
"net/http"
)
type apiCtx map[string]interface{}
// Error implements the api.Error interface and adds context to error messages.
type apiError struct {
err error
code int
context apiCtx
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *apiError) Cause() error {
return e.err
}
// Error returns an error message with additional context.
func (e *apiError) Error() string {
ret := e.err.Error()
/*
if len(e.context) > 0 {
ret += "\n\nContext:"
for k, v := range e.context {
ret += fmt.Sprintf("\n %s: %v", k, v)
}
}
*/
return ret
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// StatusCode returns an http status code indicating the type and severity of
// the error.
func (e *apiError) StatusCode() int {
if e.code == 0 {
return http.StatusInternalServerError
}
return e.code
}
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *apiError) MarshalJSON() ([]byte, error) {
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *apiError) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.code = er.Status
e.err = fmt.Errorf(er.Message)
return nil
}

View file

@ -5,6 +5,7 @@ import (
"crypto/x509"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
)
// ACME is the acme provisioner type, an entity that can authorize the ACME
@ -79,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
}
return nil
}

View file

@ -3,11 +3,13 @@ package provisioner
import (
"context"
"crypto/x509"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
)
func TestACME_Getters(t *testing.T) {
@ -88,86 +90,98 @@ func TestACME_Init(t *testing.T) {
}
}
func TestACME_AuthorizeRevoke(t *testing.T) {
p, err := generateACME()
assert.FatalError(t, err)
assert.Nil(t, p.AuthorizeRevoke(context.TODO(), ""))
}
func TestACME_AuthorizeRenew(t *testing.T) {
p1, err := generateACME()
assert.FatalError(t, err)
p2, err := generateACME()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
type test struct {
p *ACME
cert *x509.Certificate
}
tests := []struct {
name string
prov *ACME
args args
err error
}{
{"ok", p1, args{nil}, nil},
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
code int
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()),
}
},
"ok": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tt.err)
assert.Nil(t, tc.err)
}
})
}
}
func TestACME_AuthorizeSign(t *testing.T) {
p1, err := generateACME()
assert.FatalError(t, err)
tests := []struct {
name string
prov *ACME
method Method
err error
}{
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
{"ok", p1, SignMethod, nil},
type test struct {
p *ACME
token string
code int
err error
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), tt.method)
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
tests := map[string]func(*testing.T) test{
"ok": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 4, got)
for _, o := range got {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 4, opts)
for _, o := range opts {
switch v := o.(type) {
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME))
assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}

View file

@ -16,6 +16,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -271,7 +272,7 @@ func (p *AWS) Init(config Config) (err error) {
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign")
}
doc := payload.document
@ -305,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
}
return nil
}
@ -349,41 +350,41 @@ func (p *AWS) readURL(url string) ([]byte, error) {
func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
}
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
}
var unsafeClaims awsPayload
if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
return nil, errors.Wrap(err, "error unmarshaling claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling claims")
}
var payload awsPayload
if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil {
return nil, errors.Wrap(err, "error verifying claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error verifying claims")
}
// Validate identity document signature
if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil {
return nil, err
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token signature")
}
var doc awsInstanceIdentityDocument
if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil {
return nil, errors.Wrap(err, "error unmarshaling identity document")
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling aws identity document")
}
switch {
case doc.AccountID == "":
return nil, errors.New("identity document accountId cannot be empty")
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
case doc.InstanceID == "":
return nil, errors.New("identity document instanceId cannot be empty")
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
case doc.PrivateIP == "":
return nil, errors.New("identity document privateIp cannot be empty")
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
case doc.Region == "":
return nil, errors.New("identity document region cannot be empty")
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -393,12 +394,12 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
Issuer: awsIssuer,
Time: now,
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token")
}
// validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
}
// Validate subject, it has to be known if disableCustomSANs is enabled
@ -406,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
return nil, errors.New("invalid token: invalid subject claim (sub)")
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
}
}
@ -420,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
}
}
if !found {
return nil, errors.New("invalid identity document: accountId is not valid")
return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(doc.PendingTime) > d {
return nil, errors.New("identity document pendingTime is too old")
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
}
}
@ -438,18 +439,18 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign")
}
doc := claims.document
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(claims.Subject),
sshCertKeyIDModifier(claims.Subject),
}
// Default to host + known IPs/hostnames
@ -461,9 +462,9 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions,
// Set the default extensions.
@ -473,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -10,12 +10,15 @@ import (
"encoding/hex"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -229,6 +232,213 @@ func TestAWS_Init(t *testing.T) {
}
}
func TestAWS_authorizeToken(t *testing.T) {
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
t.Fatal("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.FatalError(t, err)
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
type test struct {
p *AWS
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; error parsing aws token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), badKey)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws token signature"),
}
},
"fail/empty-account-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), "", "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"),
}
},
"fail/empty-instance-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"),
}
},
"fail/empty-private-ip": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"),
}
},
"fail/empty-region": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document region cannot be empty"),
}
},
"fail/invalid-token-issuer": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", "bad-issuer", p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws token"),
}
},
"fail/invalid-audience": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, "bad-audience", p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"),
}
},
"fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
p.DisableCustomSANs = true
tok, err := generateAWSToken(
"foo", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"),
}
},
"fail/invalid-account-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), "foo", "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"),
}
},
"fail/instance-age": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
p.InstanceAge = Duration{1 * time.Minute}
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document pendingTime is too old"),
}
},
"ok": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
assert.Equals(t, claims.Subject, "instance-id")
assert.Equals(t, claims.Issuer, awsIssuer)
assert.NotNil(t, claims.Amazon)
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
assert.FatalError(t, err)
assert.Equals(t, claims.Audience[0], aud)
}
}
})
}
}
func TestAWS_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
@ -326,26 +536,27 @@ func TestAWS_AuthorizeSign(t *testing.T) {
aws *AWS
args args
wantLen int
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, false},
{"ok", p2, args{t2}, 7, false},
{"ok", p2, args{t2Hostname}, 7, false},
{"ok", p2, args{t2PrivateIP}, 7, false},
{"ok", p1, args{t4}, 5, false},
{"fail account", p3, args{t3}, 0, true},
{"fail token", p1, args{"token"}, 0, true},
{"fail subject", p1, args{failSubject}, 0, true},
{"fail issuer", p1, args{failIssuer}, 0, true},
{"fail audience", p1, args{failAudience}, 0, true},
{"fail account", p1, args{failAccount}, 0, true},
{"fail instanceID", p1, args{failInstanceID}, 0, true},
{"fail privateIP", p1, args{failPrivateIP}, 0, true},
{"fail region", p1, args{failRegion}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
{"fail instance age", p2, args{failInstanceAge}, 0, true},
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 7, http.StatusOK, false},
{"ok", p2, args{t2Hostname}, 7, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false},
{"ok", p1, args{t4}, 5, http.StatusOK, false},
{"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
{"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true},
{"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true},
{"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -354,8 +565,13 @@ func TestAWS_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
}
assert.Len(t, tt.wantLen, got)
})
}
}
@ -368,6 +584,14 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAWS()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
assert.FatalError(t, err)
@ -407,30 +631,35 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
aws *AWS
args args
expected *SSHOptions
code int
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false},
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true},
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false},
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.aws.AuthorizeSSHSign(ctx, tt.args.token)
got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -447,6 +676,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
})
}
}
func TestAWS_AuthorizeRenew(t *testing.T) {
p1, err := generateAWS()
assert.FatalError(t, err)
@ -466,44 +696,20 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
name string
aws *AWS
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAWS_AuthorizeRevoke(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
aws *AWS
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}

View file

@ -13,6 +13,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -209,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
return nil
}
// parseToken returns the claims, name, group, error.
func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) {
// authorizeToken returns the claims, name, group, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, "", "", errors.Wrapf(err, "error parsing token")
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
}
if len(jwt.Headers) == 0 {
return nil, "", "", errors.New("error parsing token: header is missing")
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
}
var found bool
@ -229,7 +230,7 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
}
}
if !found {
return nil, "", "", errors.New("cannot validate token")
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
}
if err := claims.ValidateWithLeeway(jose.Expected{
@ -237,17 +238,17 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
Issuer: p.oidcConfig.Issuer,
Time: time.Now(),
}, 1*time.Minute); err != nil {
return nil, "", "", errors.Wrap(err, "failed to validate payload")
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
}
// Validate TenantID
if claims.TenantID != p.TenantID {
return nil, "", "", errors.New("validation failed: invalid tenant id claim (tid)")
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
}
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 {
return nil, "", "", errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
}
group, name := re[2], re[3]
return &claims, name, group, nil
@ -256,9 +257,9 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, err := p.parseToken(token)
_, name, group, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
}
// Filter by resource group
@ -271,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
}
if !found {
return nil, errors.New("validation failed: invalid resource group")
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
}
}
@ -301,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
}
return nil
}
@ -309,16 +310,16 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
}
_, name, _, err := p.parseToken(token)
_, name, _, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
}
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(name),
sshCertKeyIDModifier(name),
}
// Default to host + known hostnames
@ -327,9 +328,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
Principals: []string{name},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions,
// Set the default extensions.
@ -339,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -15,7 +15,10 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
func TestAzure_Getters(t *testing.T) {
@ -209,6 +212,148 @@ func TestAzure_Init(t *testing.T) {
}
}
func TestAzure_authorizeToken(t *testing.T) {
type test struct {
p *Azure
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateAzure()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; error parsing azure token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; cannot validate azure token"),
}
},
"fail/invalid-token-issuer": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; failed to validate azure token payload"),
}
},
"fail/invalid-tenant-id": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"),
}
},
"fail/invalid-xms-mir-id": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
jwk := &p.keyStore.keySet.Keys[0]
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
assert.FatalError(t, err)
now := time.Now()
claims := azurePayload{
Claims: jose.Claims{
Subject: "subject",
Issuer: p.oidcConfig.Issuer,
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{azureDefaultAudience},
ID: "the-jti",
},
AppID: "the-appid",
AppIDAcr: "the-appidacr",
IdentityProvider: "the-idp",
ObjectID: "the-oid",
TenantID: p.TenantID,
Version: "the-version",
XMSMirID: "foo",
}
tok, err := jose.Signed(sig).Claims(claims).CompactSerialize()
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; error parsing xms_mirid claim - foo"),
}
},
"ok": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, claims.Subject, "subject")
assert.Equals(t, claims.Issuer, tc.p.oidcConfig.Issuer)
assert.Equals(t, claims.Audience[0], azureDefaultAudience)
assert.Equals(t, name, "virtualMachine")
assert.Equals(t, group, "resourceGroup")
}
}
})
}
}
func TestAzure_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
@ -283,19 +428,20 @@ func TestAzure_AuthorizeSign(t *testing.T) {
azure *Azure
args args
wantLen int
code int
wantErr bool
}{
{"ok", p1, args{t1}, 4, false},
{"ok", p2, args{t2}, 6, false},
{"ok", p1, args{t11}, 4, false},
{"fail tenant", p3, args{t3}, 0, true},
{"fail resource group", p4, args{t4}, 0, true},
{"fail token", p1, args{"token"}, 0, true},
{"fail issuer", p1, args{failIssuer}, 0, true},
{"fail audience", p1, args{failAudience}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p1, args{t11}, 4, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -304,8 +450,51 @@ func TestAzure_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
}
})
}
}
func TestAzure_AuthorizeRenew(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
assert.Len(t, tt.wantLen, got)
})
}
}
@ -318,6 +507,14 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err)
defer srv.Close()
p2, err := generateAzure()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
@ -349,28 +546,33 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
azure *Azure
args args
expected *SSHOptions
code int
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true},
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.azure.AuthorizeSSHSign(ctx, tt.args.token)
got, err := tt.azure.AuthorizeSSHSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -388,68 +590,6 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
}
}
func TestAzure_AuthorizeRenew(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeRevoke(t *testing.T) {
az, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
token, err := az.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok token", az, args{token}, true}, // revoke is disabled
{"bad token", az, args{"bad token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_assertConfig(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)

View file

@ -78,7 +78,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// match with server audiences
if matchesAudience(claims.Audience, audiences) {
// Use fragment to get provisioner name (GCP, AWS)
// Use fragment to get provisioner name (GCP, AWS, SSHPOP)
if fragment != "" {
return c.Load(fragment)
}

View file

@ -14,6 +14,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -210,7 +211,7 @@ func (p *GCP) Init(config Config) error {
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign")
}
ce := claims.Google.ComputeEngine
@ -239,10 +240,10 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
), nil
}
// AuthorizeRenewal returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenewal(ctx context.Context, cert *x509.Certificate) error {
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
}
return nil
}
@ -260,10 +261,10 @@ func (p *GCP) assertConfig() {
func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
}
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
}
var found bool
@ -277,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errors.Errorf("failed to validate payload: cannot find key for kid %s", kid)
return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -287,12 +288,12 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
Issuer: "https://accounts.google.com",
Time: now,
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; invalid gcp token payload")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
}
// validate subject (service account)
@ -305,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errors.New("invalid token: invalid subject claim")
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
}
}
@ -319,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errors.New("invalid token: invalid project id")
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
return nil, errors.New("token google.compute_engine.instance_creation_timestamp is too old")
return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
}
}
switch {
case claims.Google.ComputeEngine.InstanceID == "":
return nil, errors.New("token google.compute_engine.instance_id cannot be empty")
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
case claims.Google.ComputeEngine.InstanceName == "":
return nil, errors.New("token google.compute_engine.instance_name cannot be empty")
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
case claims.Google.ComputeEngine.ProjectID == "":
return nil, errors.New("token google.compute_engine.project_id cannot be empty")
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
case claims.Google.ComputeEngine.Zone == "":
return nil, errors.New("token google.compute_engine.zone cannot be empty")
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
}
return &claims, nil
@ -347,18 +348,18 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign")
}
ce := claims.Google.ComputeEngine
signOptions := []SignOption{
// set the key id to the token subject
sshCertificateKeyIDModifier(ce.InstanceName),
sshCertKeyIDModifier(ce.InstanceName),
}
// Default to host + known hostnames
@ -370,9 +371,9 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions,
// Set the default extensions
@ -382,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -16,7 +16,10 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
func TestGCP_Getters(t *testing.T) {
@ -211,6 +214,202 @@ func TestGCP_Init(t *testing.T) {
}
}
func TestGCP_authorizeToken(t *testing.T) {
type test struct {
p *GCP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; error parsing gcp token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid "),
}
},
"fail/invalid-issuer": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://foo.bar.zap", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token payload"),
}
},
"fail/invalid-serviceAccount": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken("foo",
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"),
}
},
"fail/invalid-projectID": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
p.ProjectIDs = []string{"foo", "bar"}
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"),
}
},
"fail/instance-age": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
p.InstanceAge = Duration{1 * time.Minute}
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now().Add(-1*time.Minute), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"),
}
},
"fail/empty-instance-id": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"),
}
},
"fail/empty-instance-name": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"),
}
},
"fail/empty-project-id": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"),
}
},
"fail/empty-zone": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"),
}
},
"ok": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
assert.Equals(t, claims.Subject, tc.p.ServiceAccounts[0])
assert.Equals(t, claims.Issuer, "https://accounts.google.com")
assert.NotNil(t, claims.Google)
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
assert.FatalError(t, err)
assert.Equals(t, claims.Audience[0], aud)
}
}
})
}
}
func TestGCP_AuthorizeSign(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
@ -313,24 +512,25 @@ func TestGCP_AuthorizeSign(t *testing.T) {
gcp *GCP
args args
wantLen int
code int
wantErr bool
}{
{"ok", p1, args{t1}, 4, false},
{"ok", p2, args{t2}, 6, false},
{"ok", p3, args{t3}, 4, false},
{"fail token", p1, args{"token"}, 0, true},
{"fail key", p1, args{failKey}, 0, true},
{"fail iss", p1, args{failIss}, 0, true},
{"fail aud", p1, args{failAud}, 0, true},
{"fail exp", p1, args{failExp}, 0, true},
{"fail nbf", p1, args{failNbf}, 0, true},
{"fail service account", p1, args{failServiceAccount}, 0, true},
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, true},
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, true},
{"fail instance id", p1, args{failInstanceID}, 0, true},
{"fail instance name", p1, args{failInstanceName}, 0, true},
{"fail project id", p1, args{failProjectID}, 0, true},
{"fail zone", p1, args{failZone}, 0, true},
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p3, args{t3}, 4, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
{"fail aud", p1, args{failAud}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail service account", p1, args{failServiceAccount}, 0, http.StatusUnauthorized, true},
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, http.StatusUnauthorized, true},
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, http.StatusUnauthorized, true},
{"fail instance id", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail instance name", p1, args{failInstanceName}, 0, http.StatusUnauthorized, true},
{"fail project id", p1, args{failProjectID}, 0, http.StatusUnauthorized, true},
{"fail zone", p1, args{failZone}, 0, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -339,8 +539,13 @@ func TestGCP_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
}
assert.Len(t, tt.wantLen, got)
})
}
}
@ -352,6 +557,14 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
p2, err := generateGCP()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
@ -394,30 +607,35 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
gcp *GCP
args args
expected *SSHOptions
code int
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false},
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true},
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, http.StatusOK, false, false},
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token)
got, err := tt.gcp.AuthorizeSSHSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -435,7 +653,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
}
}
func TestGCP_AuthorizeRenewal(t *testing.T) {
func TestGCP_AuthorizeRenew(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
p2, err := generateGCP()
@ -454,46 +672,20 @@ func TestGCP_AuthorizeRenewal(t *testing.T) {
name string
prov *GCP
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestGCP_AuthorizeRevoke(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
gcp *GCP
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.gcp.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}

View file

@ -3,9 +3,11 @@ package provisioner
import (
"context"
"crypto/x509"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose"
)
@ -99,12 +101,12 @@ func (p *JWK) Init(config Config) (err error) {
func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk token")
}
var claims jwtPayload
if err = jwt.Claims(p.Key, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -113,17 +115,17 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
Issuer: p.Name,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
return nil, errs.Wrapf(http.StatusUnauthorized, err, "jwk.authorizeToken; invalid jwk claims")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errors.Errorf("invalid token: invalid audience claim (aud); want %s, but got %s",
return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty")
return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
}
return &claims, nil
@ -133,14 +135,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
return err
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
}
// NOTE: This is for backwards compatibility with older versions of cli
@ -171,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
}
return nil
}
@ -179,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
}
if claims.Step == nil || claims.Step.SSH == nil {
return nil, errors.New("authorization token must be an SSH provisioning token")
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
}
opts := claims.Step.SSH
signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts),
sshCertOptionsValidator(*opts),
}
t := now()
@ -205,19 +207,19 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals))
}
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
}
if opts.KeyID != "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(opts.KeyID))
signOptions = append(signOptions, sshCertKeyIDModifier(opts.KeyID))
} else {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
}
// Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
return append(signOptions,
// Set the default extensions.
@ -229,14 +231,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke)
return err
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
}

View file

@ -7,12 +7,14 @@ import (
"crypto/rsa"
"crypto/x509"
"net"
"net/http"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -162,25 +164,29 @@ func TestJWK_authorizeToken(t *testing.T) {
name string
prov *JWK
args args
code int
err error
}{
{"fail-token", p1, args{failTok}, errors.New("error parsing token")},
{"fail-key", p1, args{failKey}, errors.New("error parsing claims")},
{"fail-claims", p1, args{failClaims}, errors.New("error parsing claims")},
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"fail-issuer", p1, args{failIss}, errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
{"fail-expired", p1, args{failExp}, errors.New("invalid token: square/go-jose/jwt: validation failed, token is expired (exp)")},
{"fail-not-before", p1, args{failNbf}, errors.New("invalid token: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
{"fail-audience", p1, args{failAud}, errors.New("invalid token: invalid audience claim (aud)")},
{"fail-subject", p1, args{failSub}, errors.New("token subject cannot be empty")},
{"ok", p1, args{t1}, nil},
{"ok-no-encrypted-key", p2, args{t2}, nil},
{"ok-no-sans", p1, args{t3}, nil},
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk token")},
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
{"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
{"fail-expired", p1, args{failExp}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token is expired (exp)")},
{"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk token audience claim (aud)")},
{"fail-subject", p1, args{failSub}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; jwk token subject cannot be empty")},
{"ok", p1, args{t1}, http.StatusOK, nil},
{"ok-no-encrypted-key", p2, args{t2}, http.StatusOK, nil},
{"ok-no-sans", p1, args{t3}, http.StatusOK, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
@ -208,15 +214,19 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
name string
prov *JWK
args args
code int
err error
}{
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"ok", p1, args{t1}, nil},
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.AuthorizeRevoke: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"ok", p1, args{t1}, http.StatusOK, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil {
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
}
@ -246,20 +256,24 @@ func TestJWK_AuthorizeSign(t *testing.T) {
name string
prov *JWK
args args
code int
err error
dns []string
emails []string
ips []net.IP
}{
{name: "fail-signature", prov: p1, args: args{failSig}, err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
{"ok-sans", p1, args{t1}, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
{"ok-no-sans", p1, args{t2}, nil, []string{"subject"}, []string{}, []net.IP{}},
{name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
{"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
@ -315,15 +329,20 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
name string
prov *JWK
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}
@ -335,6 +354,14 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
@ -382,30 +409,34 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
prov *JWK
args args
expected *SSHOptions
code int
wantErr bool
wantSignErr bool
}{
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false},
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false},
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -511,10 +542,9 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
assert.FatalError(t, err)
if got, err := tt.prov.AuthorizeSSHSign(ctx, token); (err != nil) != tt.wantErr {
if got, err := tt.prov.AuthorizeSSHSign(context.Background(), token); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
} else if !tt.wantErr && assert.NotNil(t, got) {
var opts SSHOptions
@ -535,3 +565,52 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
})
}
}
func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
type test struct {
p *JWK
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
p, err := generateJWK()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("jwk.AuthorizeSSHRevoke: jwk.authorizeToken; error parsing jwk token"),
}
},
"ok": func(t *testing.T) test {
p, err := generateJWK()
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p.EncryptedKey)
assert.FatalError(t, err)
tok, err := generateToken("subject", p.Name, testAudiences.SSHRevoke[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

View file

@ -6,8 +6,10 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ed25519"
@ -138,7 +140,8 @@ func (p *K8sSA) Init(config Config) (err error) {
func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrap(http.StatusUnauthorized, err,
"k8ssa.authorizeToken; error parsing k8sSA token")
}
var (
@ -146,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
claims k8sSAPayload
)
if p.pubKeys == nil {
return nil, errors.New("TokenReview API integration not implemented")
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
/* NOTE: We plan to support the TokenReview API in a future release.
Below is some code that should be useful when we prioritize
this integration.
@ -174,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
}
}
if !valid {
return nil, errors.New("error validating token and extracting claims")
return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -182,11 +185,11 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
if err = claims.Validate(jose.Expected{
Issuer: k8sSAIssuer,
}); err != nil {
return nil, errors.Wrapf(err, "invalid token claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; invalid k8sSA token claims")
}
if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty")
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
}
return &claims, nil
@ -196,14 +199,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
return err
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil {
return nil, err
if _, err := p.authorizeToken(token, p.audiences.Sign); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
}
return []SignOption{
@ -219,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
}
return nil
}
@ -227,17 +229,14 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("authorizeSSHSign: ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
}
_, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil {
return nil, errors.Wrap(err, "authorizeSSHSign")
if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
}
// Default to a user certificate with no principals if not set
signOptions := []SignOption{
sshCertificateDefaultsModifier{CertType: SSHUserCert},
}
signOptions := []SignOption{sshCertDefaultsModifier{CertType: SSHUserCert}}
return append(signOptions,
// Set the default extensions.
@ -247,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -3,11 +3,13 @@ package provisioner
import (
"context"
"crypto/x509"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -36,6 +38,7 @@ func TestK8sSA_authorizeToken(t *testing.T) {
p *K8sSA
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
@ -44,7 +47,24 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{
p: p,
token: "foo",
err: errors.New("error parsing token"),
code: http.StatusUnauthorized,
err: errors.New("k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"fail/not-implemented": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk)
p.pubKeys = nil
assert.FatalError(t, err)
return test{
p: p,
token: tok,
err: errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"),
code: http.StatusUnauthorized,
}
},
"fail/error-validating-token": func(t *testing.T) test {
@ -58,7 +78,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{
p: p,
token: tok,
err: errors.New("error validating token and extracting claims"),
err: errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"),
code: http.StatusUnauthorized,
}
},
"fail/invalid-issuer": func(t *testing.T) test {
@ -73,7 +94,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{
p: p,
token: tok,
err: errors.New("invalid token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
code: http.StatusUnauthorized,
err: errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
}
},
"ok": func(t *testing.T) test {
@ -94,6 +116,9 @@ func TestK8sSA_authorizeToken(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
@ -105,12 +130,12 @@ func TestK8sSA_authorizeToken(t *testing.T) {
}
}
func TestK8sSA_AuthorizeSign(t *testing.T) {
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
type test struct {
p *K8sSA
token string
ctx context.Context
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
@ -119,21 +144,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
return test{
p: p,
token: "foo",
err: errors.New("error parsing token"),
}
},
"fail/ssh-unimplemented": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(jwk.Public().Key)
assert.FatalError(t, err)
tok, err := generateK8sSAToken(jwk, nil)
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.Errorf("ssh certificates not enabled for k8s ServiceAccount provisioners"),
code: http.StatusUnauthorized,
err: errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"ok": func(t *testing.T) test {
@ -145,7 +157,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok,
}
},
@ -153,10 +164,110 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil {
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestK8sSA_AuthorizeRenew(t *testing.T) {
type test struct {
p *K8sSA
cert *x509.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()),
}
},
"ok": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestK8sSA_AuthorizeSign(t *testing.T) {
type test struct {
p *K8sSA
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"ok": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(jwk.Public().Key)
assert.FatalError(t, err)
tok, err := generateK8sSAToken(jwk, nil)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
@ -187,20 +298,37 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
}
}
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
type test struct {
p *K8sSA
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/sshCA-disabled": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
// disable sshCA
disable := false
p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()),
}
},
"fail/invalid-token": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
err: errors.New("error parsing token"),
code: http.StatusUnauthorized,
err: errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"ok": func(t *testing.T) test {
@ -219,45 +347,36 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestK8sSA_AuthorizeRenew(t *testing.T) {
p1, err := generateK8sSA(nil)
assert.FatalError(t, err)
p2, err := generateK8sSA(nil)
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *K8sSA
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
for _, o := range opts {
switch v := o.(type) {
case sshCertDefaultsModifier:
assert.Equals(t, v.CertType, SSHUserCert)
case *sshDefaultExtensionModifier:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 6)
}
}
}
})
}

View file

@ -16,14 +16,16 @@ const (
SignMethod Method = iota
// RevokeMethod is the method used to revoke X.509 certificates.
RevokeMethod
// SignSSHMethod is the method used to sign SSH certificates.
SignSSHMethod
// RenewSSHMethod is the method used to renew SSH certificates.
RenewSSHMethod
// RevokeSSHMethod is the method used to revoke SSH certificates.
RevokeSSHMethod
// RekeySSHMethod is the method used to rekey SSH certificates.
RekeySSHMethod
// RenewMethod is the method used to renew X.509 certificates.
RenewMethod
// SSHSignMethod is the method used to sign SSH certificates.
SSHSignMethod
// SSHRenewMethod is the method used to renew SSH certificates.
SSHRenewMethod
// SSHRevokeMethod is the method used to revoke SSH certificates.
SSHRevokeMethod
// SSHRekeyMethod is the method used to rekey SSH certificates.
SSHRekeyMethod
)
// String returns a string representation of the context method.
@ -33,14 +35,16 @@ func (m Method) String() string {
return "sign-method"
case RevokeMethod:
return "revoke-method"
case SignSSHMethod:
return "sign-ssh-method"
case RenewSSHMethod:
return "renew-ssh-method"
case RevokeSSHMethod:
return "revoke-ssh-method"
case RekeySSHMethod:
return "rekey-ssh-method"
case RenewMethod:
return "renew-method"
case SSHSignMethod:
return "ssh-sign-method"
case SSHRenewMethod:
return "ssh-renew-method"
case SSHRevokeMethod:
return "ssh-revoke-method"
case SSHRekeyMethod:
return "ssh-rekey-method"
default:
return "unknown"
}

View file

@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "noop", p.GetName())
assert.Equals(t, noopType, p.GetType())
assert.Equals(t, nil, p.Init(Config{}))
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo"))
assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
kid, key, ok := p.GetEncryptedKey()
assert.Equals(t, "", kid)

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -189,17 +190,17 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
Audience: jose.Audience{o.ClientID},
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return errors.Wrap(err, "failed to validate payload")
return errs.Wrap(http.StatusUnauthorized, err, "validatePayload: failed to validate oidc token payload")
}
// Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errors.New("failed to validate payload: invalid azp")
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
}
// Enforce an email claim
if p.Email == "" {
return errors.New("failed to validate payload: email not found")
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
}
// Validate domains (case-insensitive)
@ -213,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
}
}
if !found {
return errors.New("failed to validate payload: email is not allowed")
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
}
}
@ -229,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
}
}
if !found {
return errors.New("validation failed: invalid group")
return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
}
}
@ -241,13 +242,15 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrap(http.StatusUnauthorized, err,
"oidc.AuthorizeToken; error parsing oidc token")
}
// Parse claims to get the kid
var claims openIDPayload
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
return nil, errs.Wrap(http.StatusUnauthorized, err,
"oidc.AuthorizeToken; error parsing oidc token claims")
}
found := false
@ -260,11 +263,11 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
}
}
if !found {
return nil, errors.New("cannot validate token")
return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
}
if err := o.ValidatePayload(claims); err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeToken")
}
return &claims, nil
@ -276,21 +279,21 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token)
if err != nil {
return err
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke")
}
// Only admins can revoke certificates.
if o.IsAdmin(claims.Email) {
return nil
}
return errors.New("cannot revoke with non-admin token")
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
}
// AuthorizeSign validates the given token.
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := o.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign")
}
so := []SignOption{
@ -315,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
}
return nil
}
@ -323,22 +326,22 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID())
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
}
claims, err := o.authorizeToken(token)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
}
signOptions := []SignOption{
// set the key id to the token email
sshCertificateKeyIDModifier(claims.Email),
sshCertKeyIDModifier(claims.Email),
}
// Get the identity using either the default identityFunc or one injected
// externally.
iden, err := o.getIdentityFunc(o, claims.Email)
if err != nil {
return nil, errors.Wrap(err, "authorizeSSHSign")
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
}
defaults := SSHOptions{
CertType: SSHUserCert,
@ -349,12 +352,12 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Non-admin users can only use principals returned by the identityFunc, and
// can only sign user certificates.
if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
}
// Default to a user certificate with usernames as principals if those options
// are not set.
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions,
// Set the default extensions
@ -364,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{o.claimer},
&sshCertValidityValidator{o.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}
@ -374,14 +377,14 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token)
if err != nil {
return err
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke")
}
// Only admins can revoke certificates.
if o.IsAdmin(claims.Email) {
return nil
if !o.IsAdmin(claims.Email) {
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
}
return errors.New("cannot revoke with non-admin token")
return nil
}
func getAndDecode(uri string, v interface{}) error {

View file

@ -7,12 +7,14 @@ import (
"crypto/rsa"
"crypto/x509"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
)
@ -206,20 +208,21 @@ func TestOIDC_authorizeToken(t *testing.T) {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok1", p1, args{t1}, false},
{"ok2", p2, args{t2}, false},
{"fail-email", p3, args{failEmail}, true},
{"fail-domain", p3, args{failDomain}, true},
{"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true},
{"fail-issuer", p1, args{failIss}, true},
{"fail-audience", p1, args{failAud}, true},
{"fail-signature", p1, args{failSig}, true},
{"fail-expired", p1, args{failExp}, true},
{"fail-not-before", p1, args{failNbf}, true},
{"ok1", p1, args{t1}, http.StatusOK, false},
{"ok2", p2, args{t2}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
{"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, true},
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, true},
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, true},
{"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, true},
{"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, true},
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, true},
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, true},
{"fail-expired", p1, args{failExp}, http.StatusUnauthorized, true},
{"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -230,6 +233,9 @@ func TestOIDC_authorizeToken(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else {
assert.NotNil(t, got)
@ -282,21 +288,24 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok1", p1, args{t1}, false},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
{"ok1", p1, args{t1}, http.StatusOK, false},
{"admin", p3, args{okAdmin}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
got, err := tt.prov.AuthorizeSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else {
if assert.NotNil(t, got) {
@ -330,6 +339,107 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
}
}
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok1", p1, args{t1}, http.StatusUnauthorized, true},
{"admin", p3, args{okAdmin}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}
}
func TestOIDC_AuthorizeRenew(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}
}
func TestOIDC_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow()
defer fn()
@ -351,9 +461,16 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err)
p5, err := generateOIDC()
assert.FatalError(t, err)
p6, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// disable sshCA
disable := false
p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
@ -425,48 +542,53 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
prov *OIDC
args args
expected *SSHOptions
code int
wantErr bool
wantSignErr bool
}{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub},
&SSHOptions{CertType: "user", Principals: []string{"max", "mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, http.StatusOK, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, http.StatusOK, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"root"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false},
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub},
expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, http.StatusOK, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false},
{"fail-sshCA-disabled", p6, args{"foo", SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -484,36 +606,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
}
}
func TestOIDC_AuthorizeRevoke(t *testing.T) {
func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p2.Admins = []string{"root@example.com"}
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p3.Init(config))
assert.FatalError(t, p2.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
noAdmin, err := generateToken("subject", "the-issuer", p1.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
// Admin email in domains
okAdmin, err := generateToken("subject", "the-issuer", p2.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
@ -523,52 +641,22 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok1", p1, args{t1}, true},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
{"ok", p2, args{okAdmin}, http.StatusOK, false},
{"fail/invalid-token", p1, args{failEmail}, http.StatusUnauthorized, true},
{"fail/not-admin", p1, args{noAdmin}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token)
err := tt.prov.AuthorizeSSHRevoke(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestOIDC_AuthorizeRenew(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}

View file

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
@ -283,43 +284,43 @@ type base struct{}
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing x509 Certificates.
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSign")
return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
}
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking x509 Certificates.
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
return errors.New("not implemented; provisioner does not implement AuthorizeRevoke")
return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
}
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing x509 Certificates.
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return errors.New("not implemented; provisioner does not implement AuthorizeRenew")
return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
}
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing SSH Certificates.
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHSign")
return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
}
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errors.New("not implemented; provisioner does not implement AuthorizeSSHRevoke")
return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
}
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRenew")
return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
}
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey")
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
}
// Identity is the type representing an externally supplied identity that is used

View file

@ -1,10 +1,14 @@
package provisioner
import (
"context"
"net/http"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
func TestType_String(t *testing.T) {
@ -101,3 +105,93 @@ func TestDefaultIdentityFunc(t *testing.T) {
})
}
}
func TestUnimplementedMethods(t *testing.T) {
tests := []struct {
name string
p Interface
method Method
}{
{"jwk/sshRekey", &JWK{}, SSHRekeyMethod},
{"jwk/sshRenew", &JWK{}, SSHRenewMethod},
{"aws/revoke", &AWS{}, RevokeMethod},
{"aws/sshRenew", &AWS{}, SSHRenewMethod},
{"aws/rekey", &AWS{}, SSHRekeyMethod},
{"aws/sshRevoke", &AWS{}, SSHRevokeMethod},
{"azure/revoke", &Azure{}, RevokeMethod},
{"azure/sshRenew", &Azure{}, SSHRenewMethod},
{"azure/sshRekey", &Azure{}, SSHRekeyMethod},
{"azure/sshRevoke", &Azure{}, SSHRevokeMethod},
{"gcp/revoke", &GCP{}, RevokeMethod},
{"gcp/sshRenew", &GCP{}, SSHRenewMethod},
{"gcp/sshRekey", &GCP{}, SSHRekeyMethod},
{"gcp/sshRevoke", &GCP{}, SSHRevokeMethod},
{"oidc/sshRenew", &OIDC{}, SSHRenewMethod},
{"oidc/sshRekey", &OIDC{}, SSHRekeyMethod},
{"x5c/sshRenew", &X5C{}, SSHRenewMethod},
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
{"acme/revoke", &ACME{}, RevokeMethod},
{"acme/sshSign", &ACME{}, SSHSignMethod},
{"acme/sshRekey", &ACME{}, SSHRekeyMethod},
{"acme/sshRenew", &ACME{}, SSHRenewMethod},
{"acme/sshRevoke", &ACME{}, SSHRevokeMethod},
{"sshpop/sign", &SSHPOP{}, SignMethod},
{"sshpop/renew", &SSHPOP{}, RenewMethod},
{"sshpop/revoke", &SSHPOP{}, RevokeMethod},
{"sshpop/sshSign", &SSHPOP{}, SSHSignMethod},
{"k8ssa/sshRekey", &K8sSA{}, SSHRekeyMethod},
{"k8ssa/sshRenew", &K8sSA{}, SSHRenewMethod},
{"k8ssa/sshRevoke", &K8sSA{}, SSHRevokeMethod},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var (
err error
msg string
)
switch tt.method {
case SignMethod:
var signOpts []SignOption
signOpts, err = tt.p.AuthorizeSign(context.Background(), "")
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSign not implemented"
case RenewMethod:
err = tt.p.AuthorizeRenew(context.Background(), nil)
msg = "provisioner.AuthorizeRenew not implemented"
case RevokeMethod:
err = tt.p.AuthorizeRevoke(context.Background(), "")
msg = "provisioner.AuthorizeRevoke not implemented"
case SSHSignMethod:
var signOpts []SignOption
signOpts, err = tt.p.AuthorizeSSHSign(context.Background(), "")
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSSHSign not implemented"
case SSHRenewMethod:
var cert *ssh.Certificate
cert, err = tt.p.AuthorizeSSHRenew(context.Background(), "")
assert.Nil(t, cert)
msg = "provisioner.AuthorizeSSHRenew not implemented"
case SSHRekeyMethod:
var (
cert *ssh.Certificate
signOpts []SignOption
)
cert, signOpts, err = tt.p.AuthorizeSSHRekey(context.Background(), "")
assert.Nil(t, cert)
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSSHRekey not implemented"
case SSHRevokeMethod:
err = tt.p.AuthorizeSSHRevoke(context.Background(), "")
msg = "provisioner.AuthorizeSSHRevoke not implemented"
default:
t.Errorf("unexpected method %s", tt.method)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equals(t, err.Error(), msg)
})
}
}

View file

@ -30,7 +30,7 @@ type SignOption interface{}
// CertificateValidator is the interface used to validate a X.509 certificate.
type CertificateValidator interface {
SignOption
Valid(crt *x509.Certificate) error
Valid(cert *x509.Certificate, o Options) error
}
// CertificateRequestValidator is the interface used to validate a X.509
@ -106,7 +106,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
return errors.New("certificate request cannot contain an empty common name")
}
if req.Subject.CommonName != string(v) {
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v)
}
return nil
}
@ -265,35 +265,32 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
// Valid validates the certificate validity settings (notBefore/notAfter) and
// and total duration.
func (v *validityValidator) Valid(crt *x509.Certificate) error {
func (v *validityValidator) Valid(cert *x509.Certificate, o Options) error {
var (
na = crt.NotAfter.Truncate(time.Second)
nb = crt.NotBefore.Truncate(time.Second)
na = cert.NotAfter.Truncate(time.Second)
nb = cert.NotBefore.Truncate(time.Second)
now = time.Now().Truncate(time.Second)
)
// To not take into account the backdate, time.Now() will be used to
// calculate the duration if NotBefore is in the past.
var d time.Duration
if now.After(nb) {
d = na.Sub(now)
} else {
d = na.Sub(nb)
}
d := na.Sub(nb)
if na.Before(now) {
return errors.Errorf("NotAfter: %v cannot be in the past", na)
return errors.Errorf("notAfter cannot be in the past; na=%v", na)
}
if na.Before(nb) {
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
}
if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
d, v.min)
}
if d > v.max {
// NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough.
if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
d, v.max)
d, v.max+o.Backdate)
}
return nil
}

View file

@ -3,9 +3,10 @@ package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"net"
"net/url"
"reflect"
"strings"
"testing"
"time"
@ -48,22 +49,22 @@ func Test_emailOnlyIdentity_Valid(t *testing.T) {
}
func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
_shortRSA, err := pemutil.Read("./testdata/short-rsa.csr")
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
assert.FatalError(t, err)
shortRSA, ok := _shortRSA.(*x509.CertificateRequest)
assert.Fatal(t, ok)
_rsa, err := pemutil.Read("./testdata/rsa.csr")
_rsa, err := pemutil.Read("./testdata/certs/rsa.csr")
assert.FatalError(t, err)
rsaCSR, ok := _rsa.(*x509.CertificateRequest)
assert.Fatal(t, ok)
_ecdsa, err := pemutil.Read("./testdata/ecdsa.csr")
_ecdsa, err := pemutil.Read("./testdata/certs/ecdsa.csr")
assert.FatalError(t, err)
ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest)
assert.Fatal(t, ok)
_ed25519, err := pemutil.Read("./testdata/ed25519.csr")
_ed25519, err := pemutil.Read("./testdata/certs/ed25519.csr")
assert.FatalError(t, err)
ed25519CSR, ok := _ed25519.(*x509.CertificateRequest)
assert.Fatal(t, ok)
@ -246,30 +247,191 @@ func Test_ipAddressesValidator_Valid(t *testing.T) {
}
func Test_validityValidator_Valid(t *testing.T) {
type fields struct {
min time.Duration
max time.Duration
type test struct {
cert *x509.Certificate
opts Options
vv *validityValidator
err error
}
type args struct {
crt *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validityValidator{
min: tt.fields.min,
max: tt.fields.max,
tests := map[string]func() test{
"fail/notAfter-past": func() test {
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotAfter: time.Now().Add(-5 * time.Minute)},
opts: Options{},
err: errors.New("notAfter cannot be in the past"),
}
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
},
"fail/notBefore-after-notAfter": func() test {
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: time.Now().Add(10 * time.Minute),
NotAfter: time.Now().Add(5 * time.Minute)},
opts: Options{},
err: errors.New("notAfter cannot be before notBefore"),
}
},
"fail/duration-too-short": func() test {
n := now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(3 * time.Minute)},
opts: Options{},
err: errors.New("is less than the authorized minimum certificate duration of "),
}
},
"ok/duration-exactly-min": func() test {
n := now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(5 * time.Minute)},
opts: Options{},
}
},
"fail/duration-too-great": func() test {
n := now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(24*time.Hour + time.Second)},
err: errors.New("is more than the authorized maximum certificate duration of "),
}
},
"ok/duration-exactly-max": func() test {
n := time.Now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(24 * time.Hour)},
}
},
"ok/duration-exact-min-with-backdate": func() test {
now := time.Now()
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(5 * time.Minute)}
time.Sleep(time.Second)
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: cert,
opts: Options{Backdate: time.Second},
}
},
"ok/duration-exact-max-with-backdate": func() test {
backdate := time.Second
now := time.Now()
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(24*time.Hour + backdate)}
time.Sleep(backdate)
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: cert,
opts: Options{Backdate: backdate},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
if err := tt.vv.Valid(tt.cert, tt.opts); err != nil {
if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) {
assert.True(t, strings.Contains(err.Error(), tt.err.Error()),
fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error()))
}
} else {
assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not <nil>", tt.err))
}
})
}
}
func Test_profileDefaultDuration_Option(t *testing.T) {
type test struct {
so Options
pdd profileDefaultDuration
cert *x509.Certificate
valid func(*x509.Certificate)
}
tests := map[string]func() test{
"ok/notBefore-notAfter-duration-empty": func() test {
return test{
pdd: profileDefaultDuration(0),
so: Options{},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter))
assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter))
},
}
},
"ok/notBefore-set": func() test {
nb := time.Now().Add(5 * time.Minute).UTC()
return test{
pdd: profileDefaultDuration(0),
so: Options{NotBefore: NewTimeDuration(nb)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
assert.Equals(t, cert.NotBefore, nb)
assert.Equals(t, cert.NotAfter, nb.Add(24*time.Hour))
},
}
},
"ok/duration-set": func() test {
d := 4 * time.Hour
return test{
pdd: profileDefaultDuration(d),
so: Options{Backdate: time.Second},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.True(t, n.Add(d).After(cert.NotAfter))
assert.True(t, n.Add(d).Add(-1*time.Minute).Before(cert.NotAfter))
},
}
},
"ok/notAfter-set": func() test {
na := now().Add(10 * time.Minute).UTC()
return test{
pdd: profileDefaultDuration(0),
so: Options{NotAfter: NewTimeDuration(na)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.Equals(t, cert.NotAfter, na)
},
}
},
"ok/notBefore-and-notAfter-set": func() test {
nb := time.Now().Add(5 * time.Minute).UTC()
na := time.Now().Add(10 * time.Minute).UTC()
d := 4 * time.Hour
return test{
pdd: profileDefaultDuration(d),
so: Options{NotBefore: NewTimeDuration(nb), NotAfter: NewTimeDuration(na)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
assert.Equals(t, cert.NotBefore, nb)
assert.Equals(t, cert.NotAfter, na)
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
assert.FatalError(t, tt.pdd.Option(tt.so)(prof), "unexpected error")
tt.valid(prof.Subject())
})
}
}
@ -381,43 +543,3 @@ func Test_profileLimitDuration_Option(t *testing.T) {
})
}
}
func Test_profileDefaultDuration_Option(t *testing.T) {
tm, fn := mockNow()
defer fn()
v := profileDefaultDuration(24 * time.Hour)
type args struct {
so Options
}
tests := []struct {
name string
v profileDefaultDuration
args args
want *x509.Certificate
}{
{"default", v, args{Options{}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(24 * time.Hour)}},
{"backdate", v, args{Options{Backdate: 1 * time.Minute}}, &x509.Certificate{NotBefore: tm.Add(-1 * time.Minute), NotAfter: tm.Add(24 * time.Hour)}},
{"notBefore", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}}, &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
{"notAfter", v, args{Options{NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(1 * time.Hour)}},
{"notBefore and notAfter", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second)), NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}},
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(1 * time.Hour)}},
{"notBefore and backdate", v, args{Options{Backdate: 1 * time.Minute, NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}},
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert := &x509.Certificate{}
profile := &x509util.Leaf{}
profile.SetSubject(cert)
fn := tt.v.Option(tt.args.so)
if err := fn(profile); err != nil {
t.Errorf("profileDefaultDuration.Option() error = %v", err)
}
if !reflect.DeepEqual(cert, tt.want) {
t.Errorf("profileDefaultDuration.Option() = %v, \nwant %v", cert, tt.want)
}
})
}
}

View file

@ -19,29 +19,29 @@ const (
SSHHostCert = "host"
)
// SSHCertificateModifier is the interface used to change properties in an SSH
// SSHCertModifier is the interface used to change properties in an SSH
// certificate.
type SSHCertificateModifier interface {
type SSHCertModifier interface {
SignOption
Modify(cert *ssh.Certificate) error
}
// SSHCertificateOptionModifier is the interface used to add custom options used
// SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate.
type SSHCertificateOptionModifier interface {
type SSHCertOptionModifier interface {
SignOption
Option(o SSHOptions) SSHCertificateModifier
Option(o SSHOptions) SSHCertModifier
}
// SSHCertificateValidator is the interface used to validate an SSH certificate.
type SSHCertificateValidator interface {
// SSHCertValidator is the interface used to validate an SSH certificate.
type SSHCertValidator interface {
SignOption
Valid(cert *ssh.Certificate) error
Valid(cert *ssh.Certificate, opts SSHOptions) error
}
// SSHCertificateOptionsValidator is the interface used to validate the custom
// SSHCertOptionsValidator is the interface used to validate the custom
// options used to modify the SSH certificate.
type SSHCertificateOptionsValidator interface {
type SSHCertOptionsValidator interface {
SignOption
Valid(got SSHOptions) error
}
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
return sshCertTypeUInt32(o.CertType)
}
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
switch o.CertType {
case "": // ignore
@ -78,7 +78,7 @@ func (o SSHOptions) Modify(cert *ssh.Certificate) error {
case SSHHostCert:
cert.CertType = ssh.HostCert
default:
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType)
}
cert.KeyId = o.KeyID
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
return nil
}
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
// principals to the SSH certificate.
type sshCertPrincipalsModifier []string
@ -126,16 +126,16 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
// Key ID in the SSH certificate.
type sshCertificateKeyIDModifier string
type sshCertKeyIDModifier string
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
cert.KeyId = string(m)
return nil
}
// sshCertTypeModifier is an SSHCertificateModifier that sets the
// sshCertTypeModifier is an SSHCertModifier that sets the
// certificate type.
type sshCertTypeModifier string
@ -145,30 +145,30 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
// sshCertValidAfterModifier is an SSHCertModifier that sets the
// ValidAfter in the SSH certificate.
type sshCertificateValidAfterModifier uint64
type sshCertValidAfterModifier uint64
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
cert.ValidAfter = uint64(m)
return nil
}
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
// sshCertValidBeforeModifier is an SSHCertModifier that sets the
// ValidBefore in the SSH certificate.
type sshCertificateValidBeforeModifier uint64
type sshCertValidBeforeModifier uint64
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
cert.ValidBefore = uint64(m)
return nil
}
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
// sshCertDefaultsModifier implements a SSHCertModifier that
// modifies the certificate with the given options if they are not set.
type sshCertificateDefaultsModifier SSHOptions
type sshCertDefaultsModifier SSHOptions
// Modify implements the SSHCertificateModifier interface.
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
// Modify implements the SSHCertModifier interface.
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType)
}
@ -184,7 +184,7 @@ func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
// the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{}
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
}
}
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate
// sshDefaultDuration is an SSHCertModifier that sets the certificate
// ValidAfter and ValidBefore if they have not been set. It will fail if a
// CertType has not been set or is not valid.
type sshDefaultDuration struct {
*Claimer
}
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier {
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
@ -248,7 +248,7 @@ type sshLimitDuration struct {
NotAfter time.Time
}
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o)
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
})
}
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
// sshCertOptionsValidator validates the user SSHOptions with the ones
// usually present in the token.
type sshCertificateOptionsValidator SSHOptions
type sshCertOptionsValidator SSHOptions
// Valid implements SSHCertificateOptionsValidator and returns nil if both
// Valid implements SSHCertOptionsValidator and returns nil if both
// SSHOptions match.
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
want := SSHOptions(v)
return want.match(got)
}
type sshCertificateValidityValidator struct {
type sshCertValidityValidator struct {
*Claimer
}
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SSHOptions) error {
switch {
case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0")
@ -336,31 +336,26 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
// To not take into account the backdate, time.Now() will be used to
// calculate the duration if ValidAfter is in the past.
var dur time.Duration
if t := now().Unix(); t > int64(cert.ValidAfter) {
dur = time.Duration(int64(cert.ValidBefore)-t) * time.Second
} else {
dur = time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
}
dur := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
switch {
case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+
"accepted duration for selected provisioner of %s", dur, min)
case dur > max:
case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+
"accepted duration for selected provisioner of %s", dur, max)
"accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default:
return nil
}
}
// sshCertificateDefaultValidator implements a simple validator for all the
// sshCertDefaultValidator implements a simple validator for all the
// fields in the SSH certificate.
type sshCertificateDefaultValidator struct{}
type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
switch {
case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty")
@ -395,7 +390,7 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil")
}
@ -425,7 +420,7 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
type sshCertKeyIDValidator string
// Valid returns an error if the given certificate does not contain the necessary fields.
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate) error {
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if string(v) != cert.KeyId {
return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId)
}

View file

@ -38,12 +38,463 @@ func TestSSHOptions_Type(t *testing.T) {
}
}
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
func TestSSHOptions_Modify(t *testing.T) {
type test struct {
so *SSHOptions
cert *ssh.Certificate
valid func(*ssh.Certificate)
err error
}
tests := map[string](func() test){
"fail/unexpected-cert-type": func() test {
return test{
so: &SSHOptions{CertType: "foo"},
cert: new(ssh.Certificate),
err: errors.Errorf("ssh certificate has an unknown type - foo"),
}
},
"fail/validAfter-greater-validBefore": func() test {
return test{
so: &SSHOptions{CertType: "user"},
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
}
},
"ok/user-cert": func() test {
return test{
so: &SSHOptions{CertType: "user"},
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
},
}
},
"ok/host-cert": func() test {
return test{
so: &SSHOptions{CertType: "host"},
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
},
}
},
"ok": func() test {
va := time.Now().Add(5 * time.Minute)
vb := time.Now().Add(1 * time.Hour)
so := &SSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)}
return test{
so: so,
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
assert.Equals(t, cert.KeyId, so.KeyID)
assert.Equals(t, cert.ValidPrincipals, so.Principals)
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.so.Modify(tc.cert); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.valid(tc.cert)
}
}
})
}
}
func TestSSHOptions_Match(t *testing.T) {
type test struct {
so SSHOptions
cmp SSHOptions
err error
}
tests := map[string](func() test){
"fail/cert-type": func() test {
return test{
so: SSHOptions{CertType: "foo"},
cmp: SSHOptions{CertType: "bar"},
err: errors.Errorf("ssh certificate type does not match - got bar, want foo"),
}
},
"fail/pricipals": func() test {
return test{
so: SSHOptions{Principals: []string{"foo"}},
cmp: SSHOptions{Principals: []string{"bar"}},
err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"),
}
},
"fail/validAfter": func() test {
return test{
so: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid after does not match"),
}
},
"fail/validBefore": func() test {
return test{
so: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid before does not match"),
}
},
"ok/original-empty": func() test {
return test{
so: SSHOptions{},
cmp: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
},
}
},
"ok/cmp-empty": func() test {
return test{
cmp: SSHOptions{},
so: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
},
}
},
"ok/equal": func() test {
n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute))
vb := NewTimeDuration(n.Add(5 * time.Minute))
return test{
cmp: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: va,
ValidBefore: vb,
},
so: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: va,
ValidBefore: vb,
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.so.match(tc.cmp); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertPrincipalsModifier
cert *ssh.Certificate
expected []string
}
tests := map[string](func() test){
"ok": func() test {
a := []string{"foo", "bar"}
return test{
modifier: sshCertPrincipalsModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
}
})
}
}
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertKeyIDModifier
cert *ssh.Certificate
expected string
}
tests := map[string](func() test){
"ok": func() test {
a := "foo"
return test{
modifier: sshCertKeyIDModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.KeyId, tc.expected)
}
})
}
}
func Test_sshCertTypeModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertTypeModifier
cert *ssh.Certificate
expected uint32
}
tests := map[string](func() test){
"ok/user": func() test {
return test{
modifier: sshCertTypeModifier("user"),
cert: new(ssh.Certificate),
expected: ssh.UserCert,
}
},
"ok/host": func() test {
return test{
modifier: sshCertTypeModifier("host"),
cert: new(ssh.Certificate),
expected: ssh.HostCert,
}
},
"ok/default": func() test {
return test{
modifier: sshCertTypeModifier("foo"),
cert: new(ssh.Certificate),
expected: 0,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.CertType, uint32(tc.expected))
}
})
}
}
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertValidAfterModifier
cert *ssh.Certificate
expected uint64
}
tests := map[string](func() test){
"ok": func() test {
return test{
modifier: sshCertValidAfterModifier(15),
cert: new(ssh.Certificate),
expected: 15,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.ValidAfter, tc.expected)
}
})
}
}
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertDefaultsModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
}
tests := map[string](func() test){
"ok/changes": func() test {
n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute))
vb := NewTimeDuration(n.Add(5 * time.Minute))
so := SSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: va,
ValidBefore: vb,
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, so.Principals)
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
},
}
},
"ok/no-changes": func() test {
n := time.Now()
so := SSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: &ssh.Certificate{
CertType: uint32(ssh.UserCert),
ValidPrincipals: []string{"zap", "zoop"},
ValidAfter: 15,
ValidBefore: 25,
},
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
assert.Equals(t, cert.ValidAfter, uint64(15))
assert.Equals(t, cert.ValidBefore, uint64(25))
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
tc.valid(tc.cert)
}
})
}
}
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
type test struct {
modifier sshDefaultExtensionModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
err error
}
tests := map[string](func() test){
"fail/unexpected-cert-type": func() test {
cert := &ssh.Certificate{CertType: 3}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
err: errors.New("ssh certificate type has not been set or is invalid"),
}
},
"ok/host": func() test {
cert := &ssh.Certificate{CertType: ssh.HostCert}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
assert.Len(t, 0, cert.Extensions)
},
}
},
"ok/user/extensions-exists": func() test {
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
"foo": "bar",
}}}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
val, ok := cert.Extensions["foo"]
assert.True(t, ok)
assert.Equals(t, val, "bar")
val, ok = cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
"ok/user/no-extensions": func() test {
return test{
modifier: sshDefaultExtensionModifier{},
cert: &ssh.Certificate{CertType: ssh.UserCert},
valid: func(cert *ssh.Certificate) {
_, ok := cert.Extensions["foo"]
assert.False(t, ok)
val, ok := cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.modifier.Modify(tc.cert); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.valid(tc.cert)
}
}
})
}
}
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err)
v := sshCertificateDefaultValidator{}
v := sshCertDefaultValidator{}
tests := []struct {
name string
cert *ssh.Certificate
@ -208,7 +659,7 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil {
if err := v.Valid(tt.cert, SSHOptions{}); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -219,34 +670,39 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
}
}
func Test_sshCertificateValidityValidator(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil)
assert.FatalError(t, err)
v := sshCertificateValidityValidator{p.claimer}
v := sshCertValidityValidator{p.claimer}
n := now()
tests := []struct {
name string
cert *ssh.Certificate
opts SSHOptions
err error
}{
{
"fail/validAfter-0",
&ssh.Certificate{CertType: ssh.UserCert},
SSHOptions{},
errors.New("ssh certificate validAfter cannot be 0"),
},
{
"fail/validBefore-in-past",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be in the past"),
},
{
"fail/validBefore-before-validAfter",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be before validAfter"),
},
{
"fail/cert-type-not-set",
&ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate type has not been set"),
},
{
@ -256,6 +712,7 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
},
SSHOptions{},
errors.New("unknown ssh certificate type 3"),
},
{
@ -265,8 +722,19 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(4 * time.Minute).Unix()),
},
SSHOptions{Backdate: time.Second},
errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"),
},
{
"ok/duration-exactly-min",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(5 * time.Minute).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
},
{
"fail/duration>max",
&ssh.Certificate{
@ -274,7 +742,18 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(48 * time.Hour).Unix()),
},
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m0s"),
SSHOptions{Backdate: time.Second},
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"),
},
{
"ok/duration-exactly-max",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
},
{
"ok",
@ -283,12 +762,13 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(8 * time.Hour).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil {
if err := v.Valid(tt.cert, tt.opts); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -505,7 +985,7 @@ func Test_sshDefaultDuration_Option(t *testing.T) {
{"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}},
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false},
{"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}},
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Minute), ValidBefore: unix(17 * time.Hour)}, false},
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false},
{"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}},
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false},
{"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}},
@ -541,7 +1021,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
name string
fields fields
args args
want SSHCertificateModifier
want SSHCertModifier
}{
// TODO: Add test cases.
}

View file

@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
return nil, err
}
var mods []SSHCertificateModifier
var validators []SSHCertificateValidator
var mods []SSHCertModifier
var validators []SSHCertValidator
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case SSHCertificateModifier:
case SSHCertModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case SSHCertificateOptionModifier:
case SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case SSHCertificateValidator:
case SSHCertValidator:
validators = append(validators, o)
// validate the given SSHOptions
case SSHCertificateOptionsValidator:
case SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, err
}
@ -116,7 +116,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
// User provisioners validators
for _, v := range validators {
if err := v.Valid(cert); err != nil {
if err := v.Valid(cert, opts); err != nil {
return nil, err
}
}

View file

@ -3,11 +3,13 @@ package provisioner
import (
"context"
"encoding/base64"
"net/http"
"strconv"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
@ -99,33 +101,31 @@ func (p *SSHPOP) Init(config Config) error {
// claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
sshCert, err := ExtractSSHPOPCert(token)
sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil {
return nil, errors.Wrap(err, "authorizeToken ssh-pop")
return nil, errs.Wrap(http.StatusUnauthorized, err,
"sshpop.authorizeToken; error extracting sshpop header from token")
}
// Check for revocation.
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
return nil, errors.Wrap(err, "authorizeToken ssh-pop")
return nil, errs.Wrap(http.StatusInternalServerError, err,
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
} else if isRevoked {
return nil, errors.New("authorizeToken ssh-pop: ssh certificate has been revoked")
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
}
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
// Check validity period of the certificate.
n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
return nil, errors.New("sshpop certificate validAfter is in the future")
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
return nil, errors.New("sshpop certificate validBefore is in the past")
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok {
return nil, errors.New("ssh public key could not be cast to ssh CryptoPublicKey")
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
}
pubKey := sshCryptoPubKey.CryptoPublicKey()
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
}
}
if !found {
return nil, errors.New("error: provisioner could could not verify the sshpop header certificate")
return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
}
// Using the ssh certificates key to validate the claims accomplishes two
@ -156,7 +156,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// 2. Asserts that the claims are valid - have not been tampered with.
var claims sshPOPPayload
if err = jwt.Claims(pubKey, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error parsing sshpop token claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -165,16 +165,17 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
Issuer: p.Name,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; invalid sshpop token")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
"claim (aud): expected %s, but got %s", audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty")
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
}
claims.sshCert = sshCert
@ -186,12 +187,13 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
if err != nil {
return err
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
}
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errors.New("token subject must be equivalent to certificate serial number")
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number")
}
return err
return nil
}
// AuthorizeSSHRenew validates the authorization token and extracts/validates
@ -199,10 +201,10 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, errors.New("sshpop AuthorizeSSHRenew: sshpop certificate must be a host ssh certificate")
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, nil
@ -214,51 +216,52 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
if err != nil {
return nil, nil, err
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errors.New("sshpop AuthorizeSSHRekey: sshpop certificate must be a host ssh certificate")
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, []SignOption{
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
}, nil
}
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
// in the sshpop header. If the header is missing, an error is returned.
func ExtractSSHPOPCert(token string) (*ssh.Certificate, error) {
func ExtractSSHPOPCert(token string) (*ssh.Certificate, *jose.JSONWebToken, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, nil, errors.Wrapf(err, "extractSSHPOPCert; error parsing token")
}
encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"]
if !ok {
return nil, errors.New("token missing sshpop header")
return nil, nil, errors.New("extractSSHPOPCert; token missing sshpop header")
}
encodedSSHCertStr, ok := encodedSSHCert.(string)
if !ok {
return nil, errors.New("error unexpected type for sshpop header")
return nil, nil, errors.Errorf("extractSSHPOPCert; error unexpected type for sshpop header: "+
"want 'string', but got '%T'", encodedSSHCert)
}
sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr)
if err != nil {
return nil, errors.Wrap(err, "error decoding sshpop header")
return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error base64 decoding sshpop header")
}
sshPub, err := ssh.ParsePublicKey(sshCertBytes)
if err != nil {
return nil, errors.Wrap(err, "error parsing ssh public key")
return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error parsing ssh public key")
}
sshCert, ok := sshPub.(*ssh.Certificate)
if !ok {
return nil, errors.New("error converting ssh public key to ssh certificate")
return nil, nil, errors.New("extractSSHPOPCert; error converting ssh public key to ssh certificate")
}
return sshCert, nil
return sshCert, jwt, nil
}
func bytesForSigning(cert *ssh.Certificate) []byte {

View file

@ -0,0 +1,684 @@
package provisioner
import (
"context"
"crypto"
"crypto/rand"
"encoding/base64"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
func TestSSHPOP_Getters(t *testing.T) {
p, err := generateSSHPOP()
assert.FatalError(t, err)
id := "sshpop/" + p.Name
if got := p.GetID(); got != id {
t.Errorf("SSHPOP.GetID() = %v, want %v", got, id)
}
if got := p.GetName(); got != p.Name {
t.Errorf("SSHPOP.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeSSHPOP {
t.Errorf("SSHPOP.GetType() = %v, want %v", got, TypeSSHPOP)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("SSHPOP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil {
return nil, nil, err
}
cert.Key, err = ssh.NewPublicKey(jwk.Public().Key)
if err != nil {
return nil, nil, err
}
if err = cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err
}
return cert, jwk, nil
}
func generateSSHPOPToken(p Interface, cert *ssh.Certificate, jwk *jose.JSONWebKey) (string, error) {
return generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
}
func TestSSHPOP_authorizeToken(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/error-revoked-db-check": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, errors.New("force")
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusInternalServerError,
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
}
},
"fail/cert-already-revoked": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return true, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
}
},
"fail/cert-not-yet-valid": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert,
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"),
}
},
"fail/cert-past-validity": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert,
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"),
}
},
"fail/no-signer-found": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"),
}
},
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, otherJWK)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; error parsing sshpop token claims"),
}
},
"fail/invalid-claims-issuer": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; invalid sshpop token"),
}
},
"fail/invalid-audience": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop token has invalid audience claim (aud)"),
}
},
"fail/empty-subject": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.NotNil(t, claims)
}
}
})
}
}
func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRevoke: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/subject-not-equal-serial": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
userSigner, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
assert.FatalError(t, err)
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
cert *ssh.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRenew: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
}
}
})
}
}
func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
userSigner, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
assert.FatalError(t, err)
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
cert *ssh.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRekey: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 3, opts)
for _, o := range opts {
switch v := o.(type) {
case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
}
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
}
}
})
}
}
func TestSSHPOP_ExtractSSHPOPCert(t *testing.T) {
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
token string
cert *ssh.Certificate
jwk *jose.JSONWebKey
err error
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
return test{
token: "foo",
err: errors.New("extractSSHPOPCert; error parsing token"),
}
},
"fail/sshpop-missing": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("sub", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk)
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; token missing sshpop header"),
}
},
"fail/wrong-sshpop-type": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", 12345)
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error unexpected type for sshpop header: "),
}
},
"fail/base64decode-error": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", "!@#$%^&*")
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error base64 decoding sshpop header: illegal base64"),
}
},
"fail/parsing-sshpop-pubkey": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString([]byte("foo")))
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error parsing ssh public key"),
}
},
"ok": func(t *testing.T) test {
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
token: tok,
jwk: jwk,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, jwt, err := ExtractSSHPOPCert(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
assert.Equals(t, tc.jwk.KeyID, jwt.Headers[0].KeyID)
}
}
})
}
}

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIHzAUYu3h8e1gL5ONGZo+lghJJa9rl1TvP2UlqDXazxvoAoGCCqGSM49
AwEHoUQDQgAEOLScS+1Yzmqdyots9lSC0tzTSXUXEgyOD9wYrQ0BqnVZtBXlQw1p
m3fnF/7Ehl6bD1YZWjrF1t+IBZQMq1uBBw==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINWGD2xneE43YeytQzORItISxv6d/oH+9TXvDKHo6TyXoAoGCCqGSM49
AwEHoUQDQgAEVK/EtXgVV7+7ppnQSjCtI5qb/gIGnQUF4i//F/JKKho7kRNyMDSn
BP3kndiv8Yfxg4PsyIRY5ZofbEo5eJE6bg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
-----END EC PRIVATE KEY-----

View file

@ -19,6 +19,7 @@ import (
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
var (
@ -47,24 +48,6 @@ var (
}
)
func provisionerClaims() *Claims {
ddr := false
des := true
return &Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &ddr,
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &des,
}
}
const awsTestCertificate = `-----BEGIN CERTIFICATE-----
MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw
GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0
@ -204,7 +187,7 @@ func generateJWK() (*JWK, error) {
}
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
fooPubB, err := ioutil.ReadFile("./testdata/foo.pub")
fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub")
if err != nil {
return nil, err
}
@ -212,7 +195,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
if err != nil {
return nil, err
}
barPubB, err := ioutil.ReadFile("./testdata/bar.pub")
barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub")
if err != nil {
return nil, err
}
@ -240,6 +223,46 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
}, nil
}
func generateSSHPOP() (*SSHPOP, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil {
return nil, err
}
userKey, _, _, _, err := ssh.ParseAuthorizedKey(userB)
if err != nil {
return nil, err
}
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
if err != nil {
return nil, err
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(hostB)
if err != nil {
return nil, err
}
return &SSHPOP{
Name: name,
Type: "SSHPOP",
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey},
},
}, nil
}
func generateX5C(root []byte) (*X5C, error) {
if root == nil {
root = []byte(`-----BEGIN CERTIFICATE-----
@ -589,6 +612,13 @@ func withX5CHdr(certs []*x509.Certificate) tokOption {
}
}
func withSSHPOPFile(cert *ssh.Certificate) tokOption {
return func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal()))
return nil
}
}
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
@ -630,6 +660,24 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateX5CSSHToken(jwk *jose.JSONWebKey, claims *x5cPayload, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)
for _, o := range tokOpts {
if err := o(so); err != nil {
return "", err
}
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func getK8sSAPayload() *k8sSAPayload {
return &k8sSAPayload{
Claims: jose.Claims{

View file

@ -4,9 +4,11 @@ import (
"context"
"crypto/x509"
"encoding/pem"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose"
)
@ -121,19 +123,20 @@ func (p *X5C) Init(config Config) error {
func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c token")
}
verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{
Roots: p.rootPool,
})
if err != nil {
return nil, errors.Wrap(err, "error verifying x5c certificate chain")
return nil, errs.Wrap(http.StatusUnauthorized, err,
"x5c.authorizeToken; error verifying x5c certificate chain in token")
}
leaf := verifiedChains[0][0]
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
return nil, errors.New("certificate used to sign x5c token cannot be used for digital signature")
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
}
// Using the leaf certificates key to validate the claims accomplishes two
@ -143,7 +146,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// 2. Asserts that the claims are valid - have not been tampered with.
var claims x5cPayload
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -152,16 +155,17 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
Issuer: p.Name,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token")
return nil, errs.Wrapf(http.StatusUnauthorized, err, "x5c.authorizeToken; invalid x5c claims")
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)")
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty")
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
}
// Save the verified chains on the x5c payload object.
@ -173,14 +177,14 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
return err
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
}
// NOTE: This is for backwards compatibility with older versions of cli
@ -209,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
}
return nil
}
@ -217,22 +221,22 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
}
if claims.Step == nil || claims.Step.SSH == nil {
return nil, errors.New("authorization token must be an SSH provisioning token")
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
}
opts := claims.Step.SSH
signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts),
sshCertOptionsValidator(*opts),
}
// Add modifiers from custom claims
@ -245,18 +249,18 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}
t := now()
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
}
// Make sure to define the the KeyID
if opts.KeyID == "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
}
// Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
return append(signOptions,
// Set the default extensions.
@ -268,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key.
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -2,14 +2,16 @@ package provisioner
import (
"context"
"crypto/x509"
"net"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
)
@ -151,9 +153,15 @@ M46l92gdOozT
}
func TestX5C_authorizeToken(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
type test struct {
p *X5C
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
@ -163,7 +171,8 @@ func TestX5C_authorizeToken(t *testing.T) {
return test{
p: p,
token: "foo",
err: errors.New("error parsing token"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error parsing x5c token"),
}
},
"fail/invalid-cert-chain": func(t *testing.T) test {
@ -190,7 +199,8 @@ a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo=
return test{
p: p,
token: tok,
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
}
},
"fail/doubled-up-self-signed-cert": func(t *testing.T) test {
@ -228,7 +238,8 @@ EXAHTA9L
return test{
p: p,
token: tok,
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
}
},
"fail/digital-signature-ext-required": func(t *testing.T) test {
@ -269,7 +280,8 @@ lgsqsR63is+0YQ==
return test{
p: p,
token: tok,
err: errors.New("certificate used to sign x5c token cannot be used for digital signature"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"),
}
},
"fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test {
@ -309,74 +321,58 @@ lgsqsR63is+0YQ==
return test{
p: p,
token: tok,
err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error parsing x5c claims"),
}
},
"fail/invalid-issuer": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("", "foobar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
err: errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; invalid x5c claims"),
}
},
"fail/invalid-audience": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), "foobar", "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
err: errors.New("invalid token: invalid audience claim (aud)"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; x5c token has invalid audience claim (aud)"),
}
},
"fail/empty-subject": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
err: errors.New("token subject cannot be empty"),
code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; x5c token subject cannot be empty"),
}
},
"ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
@ -389,6 +385,9 @@ lgsqsR63is+0YQ==
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
@ -402,10 +401,15 @@ lgsqsR63is+0YQ==
}
func TestX5C_AuthorizeSign(t *testing.T) {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
type test struct {
p *X5C
token string
ctx context.Context
code int
err error
dns []string
emails []string
@ -418,56 +422,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
return test{
p: p,
token: "foo",
ctx: NewContextWithMethod(context.Background(), SignMethod),
err: errors.New("error parsing token"),
}
},
"fail/ssh/disabled": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
p.claimer.claims = provisionerClaims()
*p.claimer.claims.EnableSSHCA = false
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.Errorf("ssh ca is disabled for provisioner x5c/%s", p.GetName()),
}
},
"fail/ssh/invalid-token": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.New("authorization token must be an SSH provisioning token"),
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSign: x5c.authorizeToken; error parsing x5c token"),
}
},
"ok/empty-sans": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
@ -476,7 +435,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok,
dns: []string{"foo"},
emails: []string{},
@ -484,11 +442,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
}
},
"ok/multi-sans": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
@ -497,7 +450,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok,
dns: []string{"foo"},
emails: []string{"max@smallstep.com"},
@ -508,8 +460,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil {
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
@ -554,126 +509,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
}
}
func TestX5C_AuthorizeSSHSign(t *testing.T) {
_, fn := mockNow()
defer fn()
type test struct {
p *X5C
token string
claims *x5cPayload
err error
}
tests := map[string]func(*testing.T) test{
"fail/no-Step-claim": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
claims: new(x5cPayload),
err: errors.New("authorization token must be an SSH provisioning token"),
}
},
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{Step: new(stepPayload)},
err: errors.New("authorization token must be an SSH provisioning token"),
}
},
"ok/with-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{
Step: &stepPayload{SSH: &SSHOptions{
CertType: SSHHostCert,
Principals: []string{"max", "mariano", "alan"},
ValidAfter: TimeDuration{d: 5 * time.Minute},
ValidBefore: TimeDuration{d: 10 * time.Minute},
}},
Claims: jose.Claims{Subject: "foo"},
chains: [][]*x509.Certificate{certs},
},
}
},
"ok/without-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{
Step: &stepPayload{SSH: &SSHOptions{}},
Claims: jose.Claims{Subject: "foo"},
chains: [][]*x509.Certificate{certs},
},
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
nw := now()
for _, o := range opts {
switch v := o.(type) {
case sshCertificateOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
case sshCertificateKeyIDModifier:
assert.Equals(t, string(v), "foo")
case sshCertTypeModifier:
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
case sshCertPrincipalsModifier:
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
case sshCertificateValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
case sshCertificateValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
case sshCertificateDefaultsModifier:
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, tc.claims.chains[0][0].NotAfter)
case *sshCertificateValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertificateDefaultValidator:
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
assert.Equals(t, tot, 12)
} else {
assert.Equals(t, tot, 8)
}
}
}
}
})
}
}
func TestX5C_AuthorizeRevoke(t *testing.T) {
type test struct {
p *X5C
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
@ -683,13 +523,14 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
return test{
p: p,
token: "foo",
err: errors.New("error parsing token"),
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeRevoke: x5c.authorizeToken; error parsing x5c token"),
}
},
"ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
@ -707,8 +548,11 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
@ -719,33 +563,248 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
}
func TestX5C_AuthorizeRenew(t *testing.T) {
p1, err := generateX5C(nil)
assert.FatalError(t, err)
p2, err := generateX5C(nil)
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
type test struct {
p *X5C
code int
err error
}
tests := []struct {
name string
prov *X5C
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()),
}
},
"ok": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
}
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestX5C_AuthorizeSSHSign(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
_, fn := mockNow()
defer fn()
type test struct {
p *X5C
token string
claims *x5cPayload
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/sshCA-disabled": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
// disable sshCA
enable := false
p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()),
}
},
"fail/invalid-token": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign: x5c.authorizeToken; error parsing x5c token"),
}
},
"fail/no-Step-claim": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
}
},
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
}
},
"ok/with-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{SSH: &SSHOptions{
CertType: SSHHostCert,
Principals: []string{"max", "mariano", "alan"},
ValidAfter: TimeDuration{d: 5 * time.Minute},
ValidBefore: TimeDuration{d: 10 * time.Minute},
}},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
claims: claims,
token: tok,
}
},
"ok/without-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{SSH: &SSHOptions{}},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
claims: claims,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
nw := now()
for _, o := range opts {
switch v := o.(type) {
case sshCertOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
case sshCertKeyIDModifier:
assert.Equals(t, string(v), "foo")
case sshCertTypeModifier:
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
case sshCertPrincipalsModifier:
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
case sshCertValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
case sshCertValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
case sshCertDefaultsModifier:
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertDefaultValidator:
case sshCertKeyIDValidator:
assert.Equals(t, string(v), "foo")
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
assert.Equals(t, tot, 13)
} else {
assert.Equals(t, tot, 9)
}
}
}
}
})
}

View file

@ -2,18 +2,16 @@ package authority
import (
"crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, apiCtx{}}
return "", errs.NotFound("encrypted key with kid %s was not found", kid)
}
return key, nil
}
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("provisioner not found")
}
return p, nil
}
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
p, ok := a.provisioners.Load(id)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("provisioner not found")
}
return p, nil
}

View file

@ -7,13 +7,15 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
func TestGetEncryptedKey(t *testing.T) {
type ek struct {
a *Authority
kid string
err *apiError
a *Authority
kid string
err error
code int
}
tests := map[string]func(t *testing.T) *ek{
"ok": func(t *testing.T) *ek {
@ -32,10 +34,10 @@ func TestGetEncryptedKey(t *testing.T) {
a, err := New(c)
assert.FatalError(t, err)
return &ek{
a: a,
kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
http.StatusNotFound, apiCtx{}},
a: a,
kid: "foo",
err: errors.New("encrypted key with kid foo was not found"),
code: http.StatusNotFound,
}
},
}
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
@ -71,8 +69,9 @@ func TestGetEncryptedKey(t *testing.T) {
func TestGetProvisioners(t *testing.T) {
type gp struct {
a *Authority
err *apiError
a *Authority
err error
code int
}
tests := map[string]func(t *testing.T) *gp{
"ok": func(t *testing.T) *gp {
@ -91,14 +90,10 @@ func TestGetProvisioners(t *testing.T) {
ps, next, err := tc.a.GetProvisioners("", 0)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {

View file

@ -2,23 +2,20 @@ package authority
import (
"crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
)
// Root returns the certificate corresponding to the given SHA sum argument.
func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum)
if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
}
crt, ok := val.(*x509.Certificate)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.InternalServer("stored value is not a *x509.Certificate")
}
return crt, nil
}
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
crt, ok := v.(*x509.Certificate)
if !ok {
federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, apiCtx{}}
err = errs.InternalServer("stored value is not a *x509.Certificate")
return false
}
federation = append(federation, crt)

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
)
@ -16,12 +17,13 @@ func TestRoot(t *testing.T) {
a.certificates.Store("invaliddata", "a string") // invalid cert for testing
tests := map[string]struct {
sum string
err *apiError
sum string
err error
code int
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
"not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
"invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
}
for name, tc := range tests {
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {

View file

@ -122,10 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, &apiError{
err: errors.New("getSSHConfig: ssh is not configured"),
code: http.StatusNotFound,
}
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
}
var ts []templates.Template
@ -139,10 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
ts = a.config.Templates.SSH.Host
}
default:
return nil, &apiError{
err: errors.Errorf("getSSHConfig: type %s is not valid", typ),
code: http.StatusBadRequest,
}
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
}
// Merge user and default data
@ -174,7 +168,8 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
// hostname.
func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) {
if a.sshBastionFunc != nil {
return a.sshBastionFunc(user, hostname)
bs, err := a.sshBastionFunc(user, hostname)
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
}
if a.config.SSH != nil {
if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" {
@ -182,32 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
}
return nil, nil
}
return nil, &apiError{
err: errors.New("getSSHBastion: ssh is not configured"),
code: http.StatusNotFound,
}
}
// authorizeSSHSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSSHSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
p, err := a.authorizeToken(ctx, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
}
opts, err := p.AuthorizeSSHSign(ctx, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
}
// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertificateModifier
var validators []provisioner.SSHCertificateValidator
var mods []provisioner.SSHCertModifier
var validators []provisioner.SSHCertValidator
// Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
@ -215,38 +191,32 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case provisioner.SSHCertificateModifier:
case provisioner.SSHCertModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertificateOptionModifier:
case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case provisioner.SSHCertificateValidator:
case provisioner.SSHCertValidator:
validators = append(validators, o)
// validate the given SSHOptions
case provisioner.SSHCertificateOptionsValidator:
case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
default:
return nil, &apiError{
err: errors.Errorf("signSSH: invalid extra option type %T", o),
code: http.StatusInternalServerError,
}
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error reading random number"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error reading random number")
}
// Build base certificate with the key and some random values
@ -258,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Use opts to modify the certificate
if err := opts.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
// Use provisioner modifiers
for _, m := range mods {
if err := m.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
}
@ -273,25 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, &apiError{
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -302,71 +263,38 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Sign the certificate
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error signing certificate"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
}
cert.Signature = sig
// User provisioners validators
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
if err := v.Valid(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
}
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{
err: errors.Wrap(err, "signSSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db")
}
return cert, nil
}
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
cert, err := p.AuthorizeSSHRenew(ctx, token)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
return cert, nil
}
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "renewSSH: error reading random number"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error reading random number")
}
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errors.New("rewnewSSH: cannot renew certificate without validity period")
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
}
backdate := a.config.AuthorityConfig.Backdate.Duration
@ -393,25 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("renewSSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, &apiError{
err: errors.New("renewSSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, &apiError{
err: errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -422,79 +341,43 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
// Sign the certificate
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "renewSSH: error signing certificate"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error signing certificate")
}
cert.Signature = sig
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{
err: errors.Wrap(err, "renewSSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
}
return cert, nil
}
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
cert, opts, err := p.AuthorizeSSHRekey(ctx, token)
if err != nil {
return nil, nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRekey"),
code: http.StatusUnauthorized,
context: errContext,
}
}
return cert, opts, nil
}
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertificateValidator
var validators []provisioner.SSHCertValidator
for _, op := range signOpts {
switch o := op.(type) {
// validate the ssh.Certificate
case provisioner.SSHCertificateValidator:
case provisioner.SSHCertValidator:
validators = append(validators, o)
default:
return nil, &apiError{
err: errors.Errorf("rekeySSH: invalid extra option type %T", o),
code: http.StatusInternalServerError,
}
return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "rekeySSH: error reading random number"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error reading random number")
}
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errors.New("rekeySSH: cannot rekey certificate without validity period")
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
}
backdate := a.config.AuthorityConfig.Backdate.Duration
@ -521,25 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("rekeySSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, &apiError{
err: errors.New("rekeySSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, &apiError{
err: errors.Errorf("rekeySSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -547,80 +421,47 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
data := cert.Marshal()
data = data[:len(data)-4]
// Sign the certificate
// Sign the certificate.
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "rekeySSH: error signing certificate"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error signing certificate")
}
cert.Signature = sig
// User provisioners validators
// Apply validators from provisioner.
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden}
if err := v.Valid(cert, provisioner.SSHOptions{Backdate: backdate}); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
}
}
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{
err: errors.Wrap(err, "rekeySSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
}
return cert, nil
}
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
}
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
}
return nil
}
// SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil {
return nil, &apiError{
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
}
if subject.CertType != ssh.UserCert {
return nil, &apiError{
err: errors.New("signSSHAddUser: certificate is not a user certificate"),
code: http.StatusForbidden,
}
return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
}
if len(subject.ValidPrincipals) != 1 {
return nil, &apiError{
err: errors.New("signSSHAddUser: certificate does not have only one principal"),
code: http.StatusForbidden,
}
return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
}
var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{
err: errors.Wrap(err, "signSSHAddUser: error reading random number"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number")
}
signer := a.sshCAUserCertSignKey
@ -656,10 +497,7 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate)
cert.Signature = sig
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{
err: errors.Wrap(err, "signSSHAddUser: error storing certificate in db"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
}
return cert, nil
@ -691,14 +529,12 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
// GetSSHHosts returns a list of valid host principals.
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
if a.sshGetHostsFunc != nil {
return a.sshGetHostsFunc(cert)
hosts, err := a.sshGetHostsFunc(cert)
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
}
hostnames, err := a.db.GetSSHHostPrincipals()
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "getSSHHosts"),
code: http.StatusInternalServerError,
}
return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
}
hosts := make([]sshutil.Host, len(hostnames))

View file

@ -5,8 +5,10 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"fmt"
"net/http"
"reflect"
"testing"
"time"
@ -15,6 +17,8 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
@ -58,7 +62,7 @@ func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
type sshTestCertValidator string
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
func (v sshTestCertValidator) Valid(crt *ssh.Certificate, opts provisioner.SSHOptions) error {
if v == "" {
return nil
}
@ -76,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
type sshTestOptionsModifier string
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
return sshTestCertModifier(string(m))
}
@ -488,18 +492,18 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
want bool
wantErr bool
}{
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.db = &MockAuthDB{
isSSHHost: func(_ string) (bool, error) {
a.db = &db.MockAuthDB{
MIsSSHHost: func(_ string) (bool, error) {
return tt.fields.exists, tt.fields.err
},
}
@ -640,6 +644,9 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
_, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
@ -647,3 +654,266 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
})
}
}
func TestAuthority_GetSSHHosts(t *testing.T) {
a := testAuthority(t)
type test struct {
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
auth *Authority
cert *x509.Certificate
cmp func(got []sshutil.Host)
err error
code int
}
tests := map[string]func(t *testing.T) *test{
"fail/getHostsFunc-fail": func(t *testing.T) *test {
return &test{
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
return nil, errors.New("force")
},
cert: &x509.Certificate{},
err: errors.New("getSSHHosts: force"),
code: http.StatusInternalServerError,
}
},
"ok/getHostsFunc-defined": func(t *testing.T) *test {
hosts := []sshutil.Host{
{HostID: "1", Hostname: "foo"},
{HostID: "2", Hostname: "bar"},
}
return &test{
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
return hosts, nil
},
cert: &x509.Certificate{},
cmp: func(got []sshutil.Host) {
assert.Equals(t, got, hosts)
},
}
},
"fail/db-get-fail": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MGetSSHHostPrincipals: func() ([]string, error) {
return nil, errors.New("force")
},
})),
cert: &x509.Certificate{},
err: errors.New("getSSHHosts: force"),
code: http.StatusInternalServerError,
}
},
"ok": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MGetSSHHostPrincipals: func() ([]string, error) {
return []string{"foo", "bar"}, nil
},
})),
cert: &x509.Certificate{},
cmp: func(got []sshutil.Host) {
assert.Equals(t, got, []sshutil.Host{
{Hostname: "foo"},
{Hostname: "bar"},
})
},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
auth := tc.auth
if auth == nil {
auth = a
}
auth.sshGetHostsFunc = tc.getHostsFunc
hosts, err := auth.GetSSHHosts(tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.cmp(hosts)
}
}
})
}
}
func TestAuthority_RekeySSH(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
pub, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
signer, err := ssh.NewSignerFromKey(signKey)
assert.FatalError(t, err)
userOptions := sshTestModifier{
CertType: ssh.UserCert,
}
now := time.Now().UTC()
a := testAuthority(t)
type test struct {
auth *Authority
userSigner ssh.Signer
hostSigner ssh.Signer
cert *ssh.Certificate
key ssh.PublicKey
signOpts []provisioner.SignOption
cmpResult func(old, n *ssh.Certificate)
err error
code int
}
tests := map[string]func(t *testing.T) *test{
"fail/opts-type": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
key: pub,
signOpts: []provisioner.SignOption{userOptions},
err: errors.New("rekeySSH; invalid extra option type"),
code: http.StatusInternalServerError,
}
},
"fail/old-cert-validAfter": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
code: http.StatusBadRequest,
}
},
"fail/old-cert-validBefore": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
code: http.StatusBadRequest,
}
},
"fail/old-cert-no-user-key": func(t *testing.T) *test {
return &test{
userSigner: nil,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
},
"fail/old-cert-no-host-key": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
},
"fail/unexpected-old-cert-type": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"),
code: http.StatusBadRequest,
}
},
"fail/db-store": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MStoreSSHCertificate: func(cert *ssh.Certificate) error {
return errors.New("force")
},
})),
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; error storing certificate in db: force"),
code: http.StatusInternalServerError,
}
},
"ok": func(t *testing.T) *test {
va1 := now.Add(-24 * time.Hour)
vb1 := now.Add(-23 * time.Hour)
return &test{
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{
ValidAfter: uint64(va1.Unix()),
ValidBefore: uint64(vb1.Unix()),
CertType: ssh.UserCert,
ValidPrincipals: []string{"foo", "bar"},
KeyId: "foo",
},
key: pub,
signOpts: []provisioner.SignOption{},
cmpResult: func(old, n *ssh.Certificate) {
assert.Equals(t, n.CertType, old.CertType)
assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals)
assert.Equals(t, n.KeyId, old.KeyId)
assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix()))
assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix()))
l8r := now.Add(1 * time.Hour)
assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix()))
assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix()))
},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
auth := tc.auth
if auth == nil {
auth = a
}
a.sshCAUserCertSignKey = tc.userSigner
a.sshCAHostCertSignKey = tc.hostSigner
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.cmpResult(tc.cert, cert)
}
}
})
}
}

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
-----END EC PRIVATE KEY-----

View file

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util"
@ -60,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var (
errContext = apiCtx{"csr": csr, "signOptions": signOpts}
opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{}
issIdentity = a.intermediateIdentity
@ -75,54 +76,52 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
certValidators = append(certValidators, k)
case provisioner.CertificateRequestValidator:
if err := k.Valid(csr); err != nil {
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
}
case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts))
default:
return nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
http.StatusInternalServerError, errContext}
return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
}
}
if err := csr.CheckSignature(); err != nil {
return nil, &apiError{errors.Wrap(err, "sign: invalid certificate request"),
http.StatusBadRequest, errContext}
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...)
}
leaf, err := x509util.NewLeafProfileWithCSR(csr, issIdentity.Crt, issIdentity.Key, mods...)
if err != nil {
return nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign", opts...)
}
for _, v := range certValidators {
if err := v.Valid(leaf.Subject()); err != nil {
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
if err := v.Valid(leaf.Subject(), signOpts); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
}
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error creating new leaf certificate", opts...)
}
serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error parsing new leaf certificate"),
http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error parsing new leaf certificate", opts...)
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),
http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error parsing intermediate certificate", opts...)
}
if err = a.db.StoreCertificate(serverCert); err != nil {
if err != db.ErrNotImplemented {
return nil, &apiError{errors.Wrap(err, "sign: error storing certificate in db"),
http.StatusInternalServerError, errContext}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...)
}
}
@ -132,9 +131,11 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
// Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
}
// Issuer
@ -161,16 +162,16 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
MaxPathLenZero: oldCert.MaxPathLenZero,
OCSPServer: oldCert.OCSPServer,
IssuingCertificateURL: oldCert.IssuingCertificateURL,
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
DNSNames: oldCert.DNSNames,
EmailAddresses: oldCert.EmailAddresses,
IPAddresses: oldCert.IPAddresses,
URIs: oldCert.URIs,
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
PermittedDNSDomains: oldCert.PermittedDNSDomains,
ExcludedDNSDomains: oldCert.ExcludedDNSDomains,
PermittedIPRanges: oldCert.PermittedIPRanges,
ExcludedIPRanges: oldCert.ExcludedIPRanges,
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses,
PermittedURIDomains: oldCert.PermittedURIDomains,
ExcludedURIDomains: oldCert.ExcludedURIDomains,
@ -190,29 +191,28 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
issIdentity.Crt, issIdentity.Key)
if err != nil {
return nil, &apiError{err, http.StatusInternalServerError, apiCtx{}}
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
}
crtBytes, err := leaf.CreateCertificate()
if err != nil {
return nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Renew; error renewing certificate from existing server certificate", opts...)
}
serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil {
return nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Renew; error parsing new server certificate", opts...)
}
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil {
return nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Renew; error parsing intermediate certificate", opts...)
}
if err = a.db.StoreCertificate(serverCert); err != nil {
if err != db.ErrNotImplemented {
return nil, &apiError{errors.Wrap(err, "error storing certificate in db"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew; error storing certificate in db", opts...)
}
}
@ -236,26 +236,26 @@ type RevokeOptions struct {
// being renewed.
//
// TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
errContext := apiCtx{
"serialNumber": opts.Serial,
"reasonCode": opts.ReasonCode,
"reason": opts.Reason,
"passiveOnly": opts.PassiveOnly,
"mTLS": opts.MTLS,
"context": string(provisioner.MethodFromContext(ctx)),
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
opts := []interface{}{
errs.WithKeyVal("serialNumber", revokeOpts.Serial),
errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
errs.WithKeyVal("reason", revokeOpts.Reason),
errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
errs.WithKeyVal("MTLS", revokeOpts.MTLS),
errs.WithKeyVal("context", string(provisioner.MethodFromContext(ctx))),
}
if opts.MTLS {
errContext["certificate"] = base64.StdEncoding.EncodeToString(opts.Crt.Raw)
if revokeOpts.MTLS {
opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
} else {
errContext["ott"] = opts.OTT
opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
}
rci := &db.RevokedCertificateInfo{
Serial: opts.Serial,
ReasonCode: opts.ReasonCode,
Reason: opts.Reason,
MTLS: opts.MTLS,
Serial: revokeOpts.Serial,
ReasonCode: revokeOpts.ReasonCode,
Reason: revokeOpts.Reason,
MTLS: revokeOpts.MTLS,
RevokedAt: time.Now().UTC(),
}
@ -264,48 +264,43 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
err error
)
// If not mTLS then get the TokenID of the token.
if !opts.MTLS {
// Validate payload
token, err := jose.ParseSigned(opts.OTT)
if !revokeOpts.MTLS {
token, err := jose.ParseSigned(revokeOpts.OTT)
if err != nil {
return &apiError{errors.Wrapf(err, "revoke: error parsing token"),
http.StatusUnauthorized, errContext}
return errs.Wrap(http.StatusUnauthorized, err,
"authority.Revoke; error parsing token", opts...)
}
// Get claims w/out verification. We should have already verified this token
// earlier with a call to authorizeSSHRevoke.
// Get claims w/out verification.
var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return &apiError{errors.Wrap(err, "revoke"), http.StatusUnauthorized, errContext}
return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...)
}
// This method will also validate the audiences for JWK provisioners.
var ok bool
p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
if !ok {
return &apiError{
errors.Errorf("revoke: provisioner not found"),
http.StatusInternalServerError, errContext}
return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
}
rci.TokenID, err = p.GetTokenID(opts.OTT)
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
if err != nil {
return &apiError{errors.Wrap(err, "revoke: could not get ID for token"),
http.StatusInternalServerError, errContext}
return errs.Wrap(http.StatusInternalServerError, err,
"authority.Revoke; could not get ID for token")
}
errContext["tokenID"] = rci.TokenID
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID))
} else {
// Load the Certificate provisioner if one exists.
p, err = a.LoadProvisionerByCertificate(opts.Crt)
p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt)
if err != nil {
return &apiError{
errors.Wrap(err, "revoke: unable to load certificate provisioner"),
http.StatusUnauthorized, errContext}
return errs.Wrap(http.StatusUnauthorized, err,
"authority.Revoke: unable to load certificate provisioner", opts...)
}
}
rci.ProvisionerID = p.GetID()
errContext["provisionerID"] = rci.ProvisionerID
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
if provisioner.MethodFromContext(ctx) == provisioner.RevokeSSHMethod {
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
err = a.db.RevokeSSH(rci)
} else { // default to revoke x509
err = a.db.Revoke(rci)
@ -314,13 +309,12 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
case nil:
return nil
case db.ErrNotImplemented:
return &apiError{errors.New("revoke: no persistence layer configured"),
http.StatusNotImplemented, errContext}
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists:
return &apiError{errors.Errorf("revoke: certificate with serial number %s has already been revoked", rci.Serial),
http.StatusBadRequest, errContext}
return errs.BadRequest("authority.Revoke; certificate with serial "+
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
default:
return &apiError{err, http.StatusInternalServerError, errContext}
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
}
}
@ -330,17 +324,17 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
x509util.WithHosts(strings.Join(a.config.DNSNames, ",")))
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
}
crtBytes, err := profile.CreateCertificate()
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
}
keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey())
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
}
crtPEM := pem.EncodeToMemory(&pem.Block{
@ -352,19 +346,21 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
// to a tls.Certificate.
intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt)
if err != nil {
return nil, err
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
}
tlsCrt, err := tls.X509KeyPair(append(crtPEM,
pem.EncodeToMemory(intermediatePEM)...),
pem.EncodeToMemory(keyPEM))
if err != nil {
return nil, errors.Wrap(err, "error creating tls certificate")
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.GetTLSCertificate; error creating tls certificate")
}
// Get the 'leaf' certificate and set the attribute accordingly.
leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "error parsing tls certificate")
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.GetTLSCertificate; error parsing tls certificate")
}
tlsCrt.Leaf = leaf

View file

@ -7,7 +7,6 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"fmt"
"net/http"
@ -19,6 +18,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil"
@ -77,7 +77,7 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques
return csr
}
func TestSign(t *testing.T) {
func TestAuthority_Sign(t *testing.T) {
pub, priv, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
@ -102,7 +102,7 @@ func TestSign(t *testing.T) {
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token)
@ -113,7 +113,8 @@ func TestSign(t *testing.T) {
csr *x509.CertificateRequest
signOpts provisioner.Options
extraOpts []provisioner.SignOption
err *apiError
err error
code int
}
tests := map[string]func(*testing.T) *signTest{
"fail invalid signature": func(t *testing.T) *signTest {
@ -124,10 +125,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid certificate request"),
http.StatusBadRequest,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign; invalid certificate request"),
code: http.StatusBadRequest,
}
},
"fail invalid extra option": func(t *testing.T) *signTest {
@ -138,10 +137,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: append(extraOpts, "42"),
signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"),
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign; invalid extra option type string"),
code: http.StatusInternalServerError,
}
},
"fail merge default ASN1DN": func(t *testing.T) *signTest {
@ -153,10 +150,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign: default ASN1DN template cannot be nil"),
code: http.StatusInternalServerError,
}
},
"fail create cert": func(t *testing.T) *signTest {
@ -168,10 +163,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"),
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign; error creating new leaf certificate"),
code: http.StatusInternalServerError,
}
},
"fail provisioner duration claim": func(t *testing.T) *signTest {
@ -185,10 +178,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": _signOpts},
},
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusUnauthorized,
}
},
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -200,10 +191,8 @@ func TestSign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
code: http.StatusUnauthorized,
}
},
"fail rsa key too short": func(t *testing.T) *signTest {
@ -228,20 +217,16 @@ ZYtQ9Ot36qc=
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
code: http.StatusUnauthorized,
}
},
"fail store cert in db": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
_a := testAuthority(t)
_a.db = &MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"),
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts}}
_a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
return errors.New("force")
},
}
return &signTest{
@ -249,17 +234,15 @@ ZYtQ9Ot36qc=
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"),
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
err: errors.New("authority.Sign; error storing certificate in db: force"),
code: http.StatusInternalServerError,
}
},
"ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
_a := testAuthority(t)
_a.db = &MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error {
_a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
@ -279,15 +262,17 @@ ZYtQ9Ot36qc=
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
ctxErr, ok := err.(*errs.Error)
assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["csr"], tc.csr)
assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts)
}
} else {
leaf := certChain[0]
@ -346,7 +331,7 @@ ZYtQ9Ot36qc=
}
}
func TestRenew(t *testing.T) {
func TestAuthority_Renew(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
@ -375,9 +360,9 @@ func TestRenew(t *testing.T) {
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
assert.FatalError(t, err)
crtBytes, err := leaf.CreateCertificate()
certBytes, err := leaf.CreateCertificate()
assert.FatalError(t, err)
crt, err := x509.ParseCertificate(crtBytes)
cert, err := x509.ParseCertificate(certBytes)
assert.FatalError(t, err)
leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt,
@ -388,15 +373,16 @@ func TestRenew(t *testing.T) {
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
)
assert.FatalError(t, err)
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
certBytesNoRenew, err := leafNoRenew.CreateCertificate()
assert.FatalError(t, err)
crtNoRenew, err := x509.ParseCertificate(crtBytesNoRenew)
certNoRenew, err := x509.ParseCertificate(certBytesNoRenew)
assert.FatalError(t, err)
type renewTest struct {
auth *Authority
crt *x509.Certificate
err *apiError
cert *x509.Certificate
err error
code int
}
tests := map[string]func() (*renewTest, error){
"fail-create-cert": func() (*renewTest, error) {
@ -404,25 +390,22 @@ func TestRenew(t *testing.T) {
_a.intermediateIdentity.Key = nil
return &renewTest{
auth: _a,
crt: crt,
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
http.StatusInternalServerError, apiCtx{}},
cert: cert,
err: errors.New("authority.Renew; error renewing certificate from existing server certificate"),
code: http.StatusInternalServerError,
}, nil
},
"fail-unauthorized": func() (*renewTest, error) {
ctx := map[string]interface{}{
"serialNumber": crtNoRenew.SerialNumber.String(),
}
return &renewTest{
crt: crtNoRenew,
err: &apiError{errors.New("renew: renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, ctx},
cert: certNoRenew,
err: errors.New("authority.Renew: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
code: http.StatusUnauthorized,
}, nil
},
"success": func() (*renewTest, error) {
return &renewTest{
auth: a,
crt: crt,
cert: cert,
}, nil
},
"success-new-intermediate": func() (*renewTest, error) {
@ -430,23 +413,23 @@ func TestRenew(t *testing.T) {
assert.FatalError(t, err)
newRootBytes, err := newRootProfile.CreateCertificate()
assert.FatalError(t, err)
newRootCrt, err := x509.ParseCertificate(newRootBytes)
newRootCert, err := x509.ParseCertificate(newRootBytes)
assert.FatalError(t, err)
newIntermediateProfile, err := x509util.NewIntermediateProfile("new-intermediate",
newRootCrt, newRootProfile.SubjectPrivateKey())
newRootCert, newRootProfile.SubjectPrivateKey())
assert.FatalError(t, err)
newIntermediateBytes, err := newIntermediateProfile.CreateCertificate()
assert.FatalError(t, err)
newIntermediateCrt, err := x509.ParseCertificate(newIntermediateBytes)
newIntermediateCert, err := x509.ParseCertificate(newIntermediateBytes)
assert.FatalError(t, err)
_a := testAuthority(t)
_a.intermediateIdentity.Key = newIntermediateProfile.SubjectPrivateKey()
_a.intermediateIdentity.Crt = newIntermediateCrt
_a.intermediateIdentity.Crt = newIntermediateCert
return &renewTest{
auth: _a,
crt: crt,
cert: cert,
}, nil
},
}
@ -458,32 +441,33 @@ func TestRenew(t *testing.T) {
var certChain []*x509.Certificate
if tc.auth != nil {
certChain, err = tc.auth.Renew(tc.crt)
certChain, err = tc.auth.Renew(tc.cert)
} else {
certChain, err = a.Renew(tc.crt)
certChain, err = a.Renew(tc.cert)
}
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
ctxErr, ok := err.(*errs.Error)
assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
}
} else {
leaf := certChain[0]
intermediate := certChain[1]
if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.crt.NotAfter.Sub(crt.NotBefore))
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore))
assert.True(t, leaf.NotBefore.After(now.Add(-time.Minute)))
assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute)))
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-time.Minute)))
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template
@ -513,7 +497,7 @@ func TestRenew(t *testing.T) {
if a.intermediateIdentity.Crt.SerialNumber == tc.auth.intermediateIdentity.Crt.SerialNumber {
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
// Compare extensions: they can be in a different order
for _, ext1 := range tc.crt.Extensions {
for _, ext1 := range tc.cert.Extensions {
found := false
for _, ext2 := range leaf.Extensions {
if reflect.DeepEqual(ext1, ext2) {
@ -529,7 +513,7 @@ func TestRenew(t *testing.T) {
// We did change the intermediate before renewing.
assert.Equals(t, leaf.AuthorityKeyId, tc.auth.intermediateIdentity.Crt.SubjectKeyId)
// Compare extensions: they can be in a different order
for _, ext1 := range tc.crt.Extensions {
for _, ext1 := range tc.cert.Extensions {
// The authority key id extension should be different b/c the intermediates are different.
if ext1.Id.Equal(oidAuthorityKeyIdentifier) {
for _, ext2 := range leaf.Extensions {
@ -560,7 +544,7 @@ func TestRenew(t *testing.T) {
}
}
func TestGetTLSOptions(t *testing.T) {
func TestAuthority_GetTLSOptions(t *testing.T) {
type renewTest struct {
auth *Authority
opts *tlsutil.TLSOptions
@ -596,21 +580,12 @@ func TestGetTLSOptions(t *testing.T) {
}
}
func TestRevoke(t *testing.T) {
func TestAuthority_Revoke(t *testing.T) {
reasonCode := 2
reason := "bob was let go"
validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
validAudience := testAudiences.Revoke
now := time.Now().UTC()
getCtx := func() map[string]interface{} {
return apiCtx{
"serialNumber": "sn",
"reasonCode": reasonCode,
"reason": reason,
"mTLS": false,
"passiveOnly": false,
}
}
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
@ -619,30 +594,30 @@ func TestRevoke(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
a := testAuthority(t)
type test struct {
a *Authority
opts *RevokeOptions
err *apiError
auth *Authority
opts *RevokeOptions
err error
code int
checkErrDetails func(err *errs.Error)
}
tests := map[string]func() test{
"error/token/authorizeRevoke error": func() test {
a := testAuthority(t)
ctx := getCtx()
ctx["ott"] = "foo"
"fail/token/authorizeRevoke error": func() test {
return test{
a: a,
auth: a,
opts: &RevokeOptions{
OTT: "foo",
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
},
err: &apiError{errors.New("revoke: authorizeRevoke: authorizeToken: error parsing token"),
http.StatusUnauthorized, ctx},
err: errors.New("authority.Revoke; error parsing token"),
code: http.StatusUnauthorized,
}
},
"error/nil-db": func() test {
a := testAuthority(t)
"fail/nil-db": func() test {
cl := jwt.Claims{
Subject: "sn",
Issuer: validIssuer,
@ -654,30 +629,30 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
auth: a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("revoke: no persistence layer configured"),
http.StatusNotImplemented, ctx},
err: errors.New("authority.Revoke; no persistence layer configured"),
code: http.StatusNotImplemented,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
}
},
"error/db-revoke": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{
useToken: func(id, tok string) (bool, error) {
"fail/db-revoke": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
err: errors.New("force"),
}
Err: errors.New("force"),
}))
cl := jwt.Claims{
Subject: "sn",
@ -690,30 +665,30 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
auth: _a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("force"),
http.StatusInternalServerError, ctx},
err: errors.New("authority.Revoke: force"),
code: http.StatusInternalServerError,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
}
},
"error/already-revoked": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{
useToken: func(id, tok string) (bool, error) {
"fail/already-revoked": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
err: db.ErrAlreadyExists,
}
Err: db.ErrAlreadyExists,
}))
cl := jwt.Claims{
Subject: "sn",
@ -726,29 +701,29 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{
a: a,
auth: _a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: &apiError{errors.New("revoke: certificate with serial number sn has already been revoked"),
http.StatusBadRequest, ctx},
err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"),
code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
}
},
"ok/token": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{
useToken: func(id, tok string) (bool, error) {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
}
}))
cl := jwt.Claims{
Subject: "sn",
@ -761,7 +736,7 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return test{
a: a,
auth: _a,
opts: &RevokeOptions{
Serial: "sn",
ReasonCode: reasonCode,
@ -770,39 +745,14 @@ func TestRevoke(t *testing.T) {
},
}
},
"error/mTLS/authorizeRevoke": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
ctx := getCtx()
ctx["certificate"] = base64.StdEncoding.EncodeToString(crt.Raw)
ctx["mTLS"] = true
return test{
a: a,
opts: &RevokeOptions{
Crt: crt,
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
MTLS: true,
},
err: &apiError{errors.New("revoke: authorizeRevoke: serial number in certificate different than body"),
http.StatusUnauthorized, ctx},
}
},
"ok/mTLS": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
return test{
a: a,
auth: _a,
opts: &RevokeOptions{
Crt: crt,
Serial: "102012593071130646873265215610956555026",
@ -816,15 +766,24 @@ func TestRevoke(t *testing.T) {
for name, f := range tests {
tc := f()
t.Run(name, func(t *testing.T) {
if err := tc.a.Revoke(context.TODO(), tc.opts); err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
ctxErr, ok := err.(*errs.Error)
assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial)
assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode)
assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason)
assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS)
assert.Equals(t, ctxErr.Details["context"], string(provisioner.RevokeMethod))
if tc.checkErrDetails != nil {
tc.checkErrDetails(ctxErr)
}
}
} else {

View file

@ -22,6 +22,7 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
@ -102,7 +103,7 @@ func TestCASign(t *testing.T) {
ca: ca,
body: "invalid json",
status: http.StatusBadRequest,
errMsg: "Bad Request",
errMsg: errs.BadRequestDefaultMsg,
}
},
"fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -140,7 +141,7 @@ ZEp7knvU2psWRw==
ca: ca,
body: string(body),
status: http.StatusBadRequest,
errMsg: "Bad Request",
errMsg: errs.BadRequestDefaultMsg,
}
},
"fail unauthorized-ott": func(t *testing.T) *signTest {
@ -155,7 +156,7 @@ ZEp7knvU2psWRw==
ca: ca,
body: string(body),
status: http.StatusUnauthorized,
errMsg: "Unauthorized",
errMsg: errs.UnauthorizedDefaultMsg,
}
},
"fail commonname-claim": func(t *testing.T) *signTest {
@ -188,7 +189,7 @@ ZEp7knvU2psWRw==
ca: ca,
body: string(body),
status: http.StatusUnauthorized,
errMsg: "Unauthorized",
errMsg: errs.UnauthorizedDefaultMsg,
}
},
"ok": func(t *testing.T) *signTest {
@ -392,7 +393,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
ca: ca,
kid: "foo",
status: http.StatusNotFound,
errMsg: "Not Found",
errMsg: errs.NotFoundDefaultMsg,
}
},
"ok": func(t *testing.T) *ekt {
@ -455,7 +456,7 @@ func TestCARoot(t *testing.T) {
ca: ca,
sha: "foo",
status: http.StatusNotFound,
errMsg: "Not Found",
errMsg: errs.NotFoundDefaultMsg,
}
},
"success": func(t *testing.T) *rootTest {
@ -575,7 +576,7 @@ func TestCARenew(t *testing.T) {
ca: ca,
tlsConnState: nil,
status: http.StatusBadRequest,
errMsg: "Bad Request",
errMsg: errs.BadRequestDefaultMsg,
}
},
"request-missing-peer-certificate": func(t *testing.T) *renewTest {
@ -583,7 +584,7 @@ func TestCARenew(t *testing.T) {
ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest,
errMsg: "Bad Request",
errMsg: errs.BadRequestDefaultMsg,
}
},
"success": func(t *testing.T) *renewTest {

View file

@ -486,7 +486,7 @@ func (c *Client) Version() (*api.VersionResponse, error) {
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -497,7 +497,7 @@ retry:
}
var version api.VersionResponse
if err := readJSON(resp.Body, &version); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", u)
}
return &version, nil
}
@ -510,7 +510,7 @@ func (c *Client) Health() (*api.HealthResponse, error) {
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -521,7 +521,7 @@ retry:
}
var health api.HealthResponse
if err := readJSON(resp.Body, &health); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", u)
}
return &health, nil
}
@ -537,7 +537,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
retry:
resp, err := newInsecureClient().Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -548,12 +548,12 @@ retry:
}
var root api.RootResponse
if err := readJSON(resp.Body, &root); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", u)
}
// verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
return nil, errors.New("root certificate SHA256 fingerprint do not match")
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
}
return &root, nil
}
@ -564,13 +564,13 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -581,7 +581,7 @@ retry:
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", u)
}
// Add tls.ConnectionState:
// We'll extract the root certificate from the verified chains
@ -598,7 +598,7 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
retry:
resp, err := client.Post(u.String(), "application/json", http.NoBody)
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -609,7 +609,7 @@ retry:
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", u)
}
return &sign, nil
}
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u,
errs.WithMessage("Failed to perform POST request to %s", u))
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
[]interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -974,8 +974,8 @@ retry:
}
var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u,
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint"))
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
[]interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
}
return &check, nil
}
@ -1008,13 +1008,13 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse
var retried bool
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
return nil, errors.Wrapf(err, "client.SSHBastion; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -1025,7 +1025,7 @@ retry:
}
var bastion api.SSHBastionResponse
if err := readJSON(resp.Body, &bastion); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u)
}
return &bastion, nil
}

View file

@ -16,12 +16,12 @@ import (
"testing"
"time"
"github.com/smallstep/certificates/errs"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util"
"golang.org/x/crypto/ssh"
)
@ -154,18 +154,17 @@ func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
func TestClient_Version(t *testing.T) {
ok := &api.VersionResponse{Version: "test"}
internal := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", ok, 200, false},
{"500", internal, 500, true},
{"404", notFound, 404, true},
{"ok", ok, 200, false, nil},
{"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -185,7 +184,6 @@ func TestClient_Version(t *testing.T) {
got, err := c.Version()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr)
return
}
@ -195,9 +193,7 @@ func TestClient_Version(t *testing.T) {
if got != nil {
t.Errorf("Client.Version() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Version() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
@ -209,16 +205,16 @@ func TestClient_Version(t *testing.T) {
func TestClient_Health(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", ok, 200, false},
{"not ok", nok, 500, true},
{"ok", ok, 200, false, nil},
{"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -248,9 +244,7 @@ func TestClient_Health(t *testing.T) {
if got != nil {
t.Errorf("Client.Health() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Health() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
@ -264,7 +258,6 @@ func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
@ -272,9 +265,10 @@ func TestClient_Root(t *testing.T) {
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false},
{"not found", "invalid", notFound, 404, true},
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
{"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -307,9 +301,7 @@ func TestClient_Root(t *testing.T) {
if got != nil {
t.Errorf("Client.Root() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Root() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
@ -334,8 +326,6 @@ func TestClient_Sign(t *testing.T) {
NotBefore: api.NewTimeDuration(time.Now()),
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
@ -343,11 +333,12 @@ func TestClient_Sign(t *testing.T) {
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"empty request", &api.SignRequest{}, badRequest, 403, true},
{"nil request", nil, badRequest, 403, true},
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -364,7 +355,9 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -390,9 +383,7 @@ func TestClient_Sign(t *testing.T) {
if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Sign() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
@ -409,19 +400,17 @@ func TestClient_Revoke(t *testing.T) {
OTT: "the-ott",
ReasonCode: 4,
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
request *api.RevokeRequest
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"nil request", nil, badRequest, 403, true},
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -438,7 +427,9 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -464,9 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -485,19 +474,18 @@ func TestClient_Renew(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -527,9 +515,11 @@ func TestClient_Renew(t *testing.T) {
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Renew() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -543,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{
Provisioners: provisioner.List{},
}
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
internalServerError := errs.InternalServer("Internal Server Error")
tests := []struct {
name string
@ -589,9 +579,7 @@ func TestClient_Provisioners(t *testing.T) {
if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
@ -605,7 +593,6 @@ func TestClient_ProvisionerKey(t *testing.T) {
ok := &api.ProvisionerKeyResponse{
Key: "an encrypted key",
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
@ -613,9 +600,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", "kid", ok, 200, false},
{"fail", "invalid", notFound, 500, true},
{"ok", "kid", ok, 200, false, nil},
{"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -648,9 +636,11 @@ func TestClient_ProvisionerKey(t *testing.T) {
if got != nil {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
@ -666,19 +656,17 @@ func TestClient_Roots(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -708,9 +696,10 @@ func TestClient_Roots(t *testing.T) {
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -726,19 +715,16 @@ func TestClient_Federation(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -768,9 +754,10 @@ func TestClient_Federation(t *testing.T) {
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
@ -790,16 +777,16 @@ func TestClient_SSHRoots(t *testing.T) {
HostKeys: []api.SSHPublicKey{{PublicKey: key}},
UserKeys: []api.SSHPublicKey{{PublicKey: key}},
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"not found", notFound, 404, true},
{"ok", ok, 200, false, nil},
{"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -829,9 +816,10 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
@ -881,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
nok := errs.InternalServer("Internal Server Error")
httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close()
@ -948,7 +936,6 @@ func TestClient_SSHBastion(t *testing.T) {
Hostname: "bastion.local",
},
}
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
@ -956,11 +943,11 @@ func TestClient_SSHBastion(t *testing.T) {
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false},
{"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true},
{"empty request", &api.SSHBastionRequest{}, badRequest, 403, true},
{"nil request", nil, badRequest, 403, true},
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -990,8 +977,11 @@ func TestClient_SSHBastion(t *testing.T) {
if got != nil {
t.Errorf("Client.SSHBastion() = %v, want nil", got)
}
if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response)
if tt.responseCode != 200 {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
}
default:
if !reflect.DeepEqual(got, tt.response) {

View file

@ -276,6 +276,7 @@ func TestIdentity_Renew(t *testing.T) {
}
oldIdentityDir := identityDir
identityDir = "testdata/identity"
defer func() {
identityDir = oldIdentityDir
os.RemoveAll(tmpDir)

View file

@ -40,7 +40,7 @@ func (c *mutableTLSConfig) Init(base *tls.Config) {
// tls.Config GetConfigForClient.
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
c.RLock()
config = c.config
config = c.config.Clone()
c.RUnlock()
return
}

View file

@ -80,7 +80,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
func (r *TLSRenewer) Run() {
cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.Lock()
r.timer = time.AfterFunc(next, r.renewCertificate)
r.Unlock()
}
// RunContext starts the certificate renewer for the given certificate.

View file

@ -270,6 +270,105 @@ func (db *DB) Shutdown() error {
return nil
}
// MockAuthDB mocks the AuthDB interface. //
type MockAuthDB struct {
Err error
Ret1 interface{}
MIsRevoked func(string) (bool, error)
MIsSSHRevoked func(string) (bool, error)
MRevoke func(rci *RevokedCertificateInfo) error
MRevokeSSH func(rci *RevokedCertificateInfo) error
MStoreCertificate func(crt *x509.Certificate) error
MUseToken func(id, tok string) (bool, error)
MIsSSHHost func(principal string) (bool, error)
MStoreSSHCertificate func(crt *ssh.Certificate) error
MGetSSHHostPrincipals func() ([]string, error)
MShutdown func() error
}
// IsRevoked mock.
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
if m.MIsRevoked != nil {
return m.MIsRevoked(sn)
}
return m.Ret1.(bool), m.Err
}
// IsSSHRevoked mock.
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
if m.MIsSSHRevoked != nil {
return m.MIsSSHRevoked(sn)
}
return m.Ret1.(bool), m.Err
}
// UseToken mock.
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
if m.MUseToken != nil {
return m.MUseToken(id, tok)
}
if m.Ret1 == nil {
return false, m.Err
}
return m.Ret1.(bool), m.Err
}
// Revoke mock.
func (m *MockAuthDB) Revoke(rci *RevokedCertificateInfo) error {
if m.MRevoke != nil {
return m.MRevoke(rci)
}
return m.Err
}
// RevokeSSH mock.
func (m *MockAuthDB) RevokeSSH(rci *RevokedCertificateInfo) error {
if m.MRevokeSSH != nil {
return m.MRevokeSSH(rci)
}
return m.Err
}
// StoreCertificate mock.
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.MStoreCertificate != nil {
return m.MStoreCertificate(crt)
}
return m.Err
}
// IsSSHHost mock.
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
if m.MIsSSHHost != nil {
return m.MIsSSHHost(principal)
}
return m.Ret1.(bool), m.Err
}
// StoreSSHCertificate mock.
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
if m.MStoreSSHCertificate != nil {
return m.MStoreSSHCertificate(crt)
}
return m.Err
}
// GetSSHHostPrincipals mock.
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
if m.MGetSSHHostPrincipals != nil {
return m.MGetSSHHostPrincipals()
}
return m.Ret1.([]string), m.Err
}
// Shutdown mock.
func (m *MockAuthDB) Shutdown() error {
if m.MShutdown != nil {
return m.MShutdown()
}
return m.Err
}
// MockNoSQLDB //
type MockNoSQLDB struct {
Err error

View file

@ -21,9 +21,9 @@ type StackTracer interface {
// Option modifies the Error type.
type Option func(e *Error) error
// WithMessage returns an Option that modifies the error by overwriting the
// withDefaultMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
func WithMessage(format string, args ...interface{}) Option {
func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
if len(e.Msg) > 0 {
return e
@ -33,31 +33,33 @@ func WithMessage(format string, args ...interface{}) Option {
}
}
// Error represents the CA API errors.
type Error struct {
Status int
Err error
Msg string
// WithMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
func WithMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
e.Msg = fmt.Sprintf(format, args...)
return e
}
}
// New returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func New(status int, err error, opts ...Option) error {
var e *Error
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
// WithKeyVal returns an Option that adds the given key-value pair to the
// Error details. This is helpful for debugging errors.
func WithKeyVal(key string, val interface{}) Option {
return func(e *Error) error {
if e.Details == nil {
e.Details = make(map[string]interface{})
}
e.Details[key] = val
return e
}
for _, o := range opts {
o(e)
}
return e
}
// Error represents the CA API errors.
type Error struct {
Status int
Err error
Msg string
Details map[string]interface{}
}
// ErrorResponse represents an error in JSON format.
@ -92,10 +94,11 @@ func (e *Error) Message() string {
// Wrap returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil.
func Wrap(status int, e error, m string, opts ...Option) error {
func Wrap(status int, e error, m string, args ...interface{}) error {
if e == nil {
return nil
}
_, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok {
err.Err = errors.Wrap(err.Err, m)
e = err
@ -111,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
if e == nil {
return nil
}
var opts []Option
for i, arg := range args {
// Once we find the first Option, assume that all further arguments are Options.
if _, ok := arg.(Option); ok {
for _, a := range args[i:] {
// Ignore any arguments after the first Option that are not Options.
if opt, ok := a.(Option); ok {
opts = append(opts, opt)
}
}
args = args[:i]
break
}
}
as, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok {
err.Err = errors.Wrapf(err.Err, format, args...)
e = err
} else {
e = errors.Wrapf(e, format, args...)
e = errors.Wrapf(e, format, as...)
}
return StatusCodeError(status, e, opts...)
}
@ -174,77 +164,172 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error {
switch code {
case http.StatusBadRequest:
return BadRequest(e, opts...)
return BadRequestErr(e, opts...)
case http.StatusUnauthorized:
return Unauthorized(e, opts...)
return UnauthorizedErr(e, opts...)
case http.StatusForbidden:
return Forbidden(e, opts...)
return ForbiddenErr(e, opts...)
case http.StatusInternalServerError:
return InternalServerError(e, opts...)
return InternalServerErr(e, opts...)
case http.StatusNotImplemented:
return NotImplemented(e, opts...)
return NotImplementedErr(e, opts...)
default:
return UnexpectedError(code, e, opts...)
return UnexpectedErr(code, e, opts...)
}
}
var seeLogs = "Please see the certificate authority logs for more info."
var (
seeLogs = "Please see the certificate authority logs for more info."
// BadRequestDefaultMsg 400 default msg
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
// UnauthorizedDefaultMsg 401 default msg
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
// ForbiddenDefaultMsg 403 default msg
ForbiddenDefaultMsg = "The request was forbidden by the certificate authority. " + seeLogs
// NotFoundDefaultMsg 404 default msg
NotFoundDefaultMsg = "The requested resource could not be found. " + seeLogs
// InternalServerErrorDefaultMsg 500 default msg
InternalServerErrorDefaultMsg = "The certificate authority encountered an Internal Server Error. " + seeLogs
// NotImplementedDefaultMsg 501 default msg
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
)
// InternalServerError returns a 500 error with the given error.
func InternalServerError(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs))
// splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
indexOptionStart := -1
for i, a := range args {
if _, ok := a.(Option); ok {
indexOptionStart = i
break
}
}
return New(http.StatusInternalServerError, err, opts...)
if indexOptionStart < 0 {
return args, []Option{}
}
opts := []Option{}
// Ignore any non-Option args that come after the first Option.
for _, o := range args[indexOptionStart:] {
if opt, ok := o.(Option); ok {
opts = append(opts, opt)
}
}
return args[:indexOptionStart], opts
}
// NotImplemented returns a 501 error with the given error.
func NotImplemented(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs))
// NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error {
var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
}
return New(http.StatusNotImplemented, err, opts...)
for _, o := range opts {
o(e)
}
return e
}
// BadRequest returns an 400 error with the given error.
func BadRequest(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+
"missing critical data. "+seeLogs))
// Errorf creates a new error using the given format and status code.
func Errorf(code int, format string, args ...interface{}) error {
as, opts := splitOptionArgs(args)
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
for _, o := range opts {
o(e)
}
return New(http.StatusBadRequest, err, opts...)
return e
}
// Unauthorized returns an 401 error with the given error.
func Unauthorized(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs))
}
return New(http.StatusUnauthorized, err, opts...)
// InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
return Errorf(http.StatusInternalServerError, format, args...)
}
// Forbidden returns an 403 error with the given error.
func Forbidden(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs))
}
return New(http.StatusForbidden, err, opts...)
// InternalServerErr returns a 500 error with the given error.
func InternalServerErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
return NewErr(http.StatusInternalServerError, err, opts...)
}
// NotFound returns an 404 error with the given error.
func NotFound(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs))
}
return New(http.StatusNotFound, err, opts...)
// NotImplemented creates a 501 error with the given format and arguments.
func NotImplemented(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
return Errorf(http.StatusNotImplemented, format, args...)
}
// UnexpectedError will be used when the certificate authority makes an outgoing
// NotImplementedErr returns a 501 error with the given error.
func NotImplementedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
return NewErr(http.StatusNotImplemented, err, opts...)
}
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
return Errorf(http.StatusBadRequest, format, args...)
}
// BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, err, opts...)
}
// Unauthorized creates a 401 error with the given format and arguments.
func Unauthorized(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
return Errorf(http.StatusUnauthorized, format, args...)
}
// UnauthorizedErr returns an 401 error with the given error.
func UnauthorizedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
return NewErr(http.StatusUnauthorized, err, opts...)
}
// Forbidden creates a 403 error with the given format and arguments.
func Forbidden(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
return Errorf(http.StatusForbidden, format, args...)
}
// ForbiddenErr returns an 403 error with the given error.
func ForbiddenErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
return NewErr(http.StatusForbidden, err, opts...)
}
// NotFound creates a 404 error with the given format and arguments.
func NotFound(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
return Errorf(http.StatusNotFound, format, args...)
}
// NotFoundErr returns an 404 error with the given error.
func NotFoundErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
return NewErr(http.StatusNotFound, err, opts...)
}
// UnexpectedErr will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code.
func UnexpectedError(code int, err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code))
}
return New(code, err, opts...)
func UnexpectedErr(code int, err error, opts ...Option) error {
opts = append(opts, withDefaultMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code))
return NewErr(code, err, opts...)
}

View file

@ -1,11 +1,9 @@
package api
package errs
import (
"fmt"
"reflect"
"testing"
"github.com/smallstep/certificates/errs"
)
func TestError_MarshalJSON(t *testing.T) {
@ -24,7 +22,7 @@ func TestError_MarshalJSON(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &errs.Error{
e := &Error{
Status: tt.fields.Status,
Err: tt.fields.Err,
}
@ -47,15 +45,15 @@ func TestError_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
args args
expected *errs.Error
expected *Error
wantErr bool
}{
{"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &errs.Error{Status: 400, Err: fmt.Errorf("bad request")}, false},
{"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &errs.Error{}, true},
{"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &Error{Status: 400, Err: fmt.Errorf("bad request")}, false},
{"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &Error{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := new(errs.Error)
e := new(Error)
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}