Merge pull request #161 from smallstep/unittests

Introduce generalized statusCoder errors and loads of ssh unit tests.
This commit is contained in:
Max 2020-01-24 16:16:00 -08:00 committed by GitHub
commit f3f8ee4207
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
88 changed files with 5620 additions and 2544 deletions

View file

@ -63,6 +63,7 @@ issues:
- declaration of "err" shadows declaration at line - declaration of "err" shadows declaration at line
- should have a package comment, unless it's in another file for this package - should have a package comment, unless it's in another file for this package
- error strings should not be capitalized or end with punctuation or a newline - error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args
# golangci.com configuration # golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration # https://github.com/golangci/golangci/wiki/Configuration
service: service:

View file

@ -5,7 +5,6 @@ import (
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
@ -209,14 +208,6 @@ type RootResponse struct {
RootPEM Certificate `json:"ca"` RootPEM Certificate `json:"ca"`
} }
// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
CsrPEM CertificateRequest `json:"csr"`
OTT string `json:"ott"`
NotAfter TimeDuration `json:"notAfter"`
NotBefore TimeDuration `json:"notBefore"`
}
// ProvisionersResponse is the response object that returns the list of // ProvisionersResponse is the response object that returns the list of
// provisioners. // provisioners.
type ProvisionersResponse struct { type ProvisionersResponse struct {
@ -230,31 +221,6 @@ type ProvisionerKeyResponse struct {
Key string `json:"key"` Key string `json:"key"`
} }
// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest(errors.New("missing csr"))
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
}
if s.OTT == "" {
return errs.BadRequest(errors.New("missing ott"))
}
return nil
}
// SignResponse is the response object of the certificate signature request.
type SignResponse struct {
ServerPEM Certificate `json:"crt"`
CaPEM Certificate `json:"ca"`
CertChainPEM []Certificate `json:"certChain"`
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
TLS *tls.ConnectionState `json:"-"`
}
// RootsResponse is the response object of the roots request. // RootsResponse is the response object of the roots request.
type RootsResponse struct { type RootsResponse struct {
Certificates []Certificate `json:"crts"` Certificates []Certificate `json:"crts"`
@ -329,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI))) WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
} }
@ -344,91 +310,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
return certChainPEM return certChainPEM
} }
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
opts := provisioner.Options{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}
// Provisioners returns the list of provisioners configured in the authority. // Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r) cursor, limit, err := parseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &ProvisionersResponse{ JSON(w, &ProvisionersResponse{
@ -442,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := h.Authority.GetEncryptedKey(kid)
if err != nil { if err != nil {
WriteError(w, errs.NotFound(err)) WriteError(w, errs.NotFoundErr(err))
return return
} }
JSON(w, &ProvisionerKeyResponse{key}) JSON(w, &ProvisionerKeyResponse{key})
@ -452,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
@ -470,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := h.Authority.GetFederation()
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

View file

@ -28,6 +28,7 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
@ -914,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
@ -934,13 +935,13 @@ func Test_caHandler_Renew(t *testing.T) {
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := ioutil.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
} }
if tt.statusCode < http.StatusBadRequest { if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) { if !bytes.Equal(bytes.TrimSpace(body), expected) {
@ -1009,8 +1010,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`) expectedError400 := errs.BadRequest("force")
expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`) expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force")
expectedError500Bytes, err := json.Marshal(expectedError500)
assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ h := &caHandler{
@ -1035,12 +1040,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
} else { } else {
switch tt.statusCode { switch tt.statusCode {
case 400: case 400:
if !bytes.Equal(bytes.TrimSpace(body), expectedError400) { if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400) t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
} }
case 500: case 500:
if !bytes.Equal(bytes.TrimSpace(body), expectedError500) { if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500) t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
} }
default: default:
t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode) t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)
@ -1077,7 +1082,9 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
} }
expected := []byte(`{"key":"` + privKey + `"}`) expected := []byte(`{"key":"` + privKey + `"}`)
expectedError := []byte(`{"status":404,"message":"Not Found"}`) expectedError404 := errs.NotFound("force")
expectedError404Bytes, err := json.Marshal(expectedError404)
assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1101,8 +1108,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected) t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
} }
} else { } else {
if !bytes.Equal(bytes.TrimSpace(body), expectedError) { if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError) t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes)
} }
} }
}) })

35
api/renew.go Normal file
View file

@ -0,0 +1,35 @@
package api
import (
"net/http"
"github.com/smallstep/certificates/errs"
)
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate"))
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -30,13 +29,13 @@ type RevokeRequest struct {
// or an error if something is wrong. // or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) { func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" { if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial")) return errs.BadRequest("missing serial")
} }
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds")) return errs.BadRequest("reasonCode out of bounds")
} }
if !r.Passive { if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented")) return errs.NotImplemented("non-passive revocation not implemented")
} }
return return
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate"))) WriteError(w, errs.BadRequest("missing ott or peer certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body"))) WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

View file

@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
return errs.InternalServerError(errors.New("force")) return errs.InternalServer("force")
}, },
}, },
} }

89
api/sign.go Normal file
View file

@ -0,0 +1,89 @@
package api
import (
"crypto/tls"
"net/http"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/tlsutil"
)
// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
CsrPEM CertificateRequest `json:"csr"`
OTT string `json:"ott"`
NotAfter TimeDuration `json:"notAfter"`
NotBefore TimeDuration `json:"notBefore"`
}
// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
}
if s.OTT == "" {
return errs.BadRequest("missing ott")
}
return nil
}
// SignResponse is the response object of the certificate signature request.
type SignResponse struct {
ServerPEM Certificate `json:"crt"`
CaPEM Certificate `json:"ca"`
CertChainPEM []Certificate `json:"certChain"`
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
TLS *tls.ConnectionState `json:"-"`
}
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
return
}
opts := provisioner.Options{
NotBefore: body.NotBefore,
NotAfter: body.NotAfter,
}
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}

View file

@ -249,19 +249,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return return
} }
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -282,16 +282,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ValidAfter: body.ValidAfter, ValidAfter: body.ValidAfter,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(cr, opts, signOpts...) certChain, err := h.Authority.Sign(cr, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots() keys, err := h.Authority.GetSSHRoots()
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found"))) WriteError(w, errs.NotFound("no keys found"))
return return
} }
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation() keys, err := h.Authority.GetSSHFederation()
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found"))) WriteError(w, errs.NotFound("no keys found"))
return return
} }
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
config.HostTemplates = ts config.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServerError(errors.New("it should hot get here"))) WriteError(w, errs.InternalServer("it should hot get here"))
return return
} }
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHCheckPrincipalResponse{ JSON(w, &SSHCheckPrincipalResponse{
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(cert) hosts, err := h.Authority.GetSSHHosts(cert)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHGetHostsResponse{ JSON(w, &SSHGetHostsResponse{
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
return return
} }

View file

@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identity, err := h.renewIdentityCertificate(r) identity, err := h.renewIdentityCertificate(r)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

View file

@ -36,36 +36,36 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err)) WriteError(w, errs.BadRequestErr(err))
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) _, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerError(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RenewSSH(oldCert) newCert, err := h.Authority.RenewSSH(oldCert)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }
identity, err := h.renewIdentityCertificate(r) identity, err := h.renewIdentityCertificate(r)
if err != nil { if err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
// or an error if something is wrong. // or an error if something is wrong.
func (r *SSHRevokeRequest) Validate() (err error) { func (r *SSHRevokeRequest) Validate() (err error) {
if r.Serial == "" { if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial")) return errs.BadRequest("missing serial")
} }
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds")) return errs.BadRequest("reasonCode out of bounds")
} }
if !r.Passive { if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented")) return errs.NotImplemented("non-passive revocation not implemented")
} }
if len(r.OTT) == 0 { if len(r.OTT) == 0 {
return errs.BadRequest(errors.New("missing ott")) return errs.BadRequest("missing ott")
} }
return return
} }
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return return
} }
@ -66,18 +65,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive, PassiveOnly: body.Passive,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeSSHMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err)) WriteError(w, errs.ForbiddenErr(err))
return return
} }

View file

@ -6,7 +6,6 @@ import (
"log" "log"
"net/http" "net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
// pointed by v. // pointed by v.
func ReadJSON(r io.Reader, v interface{}) error { func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil { if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequest(errors.Wrap(err, "error decoding json")) return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
} }
return nil return nil
} }

View file

@ -13,12 +13,13 @@ import (
stepJOSE "github.com/smallstep/cli/jose" stepJOSE "github.com/smallstep/cli/jose"
) )
func testAuthority(t *testing.T) *Authority { func testAuthority(t *testing.T, opts ...Option) *Authority {
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk") maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
assert.FatalError(t, err) assert.FatalError(t, err)
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
assert.FatalError(t, err) assert.FatalError(t, err)
disableRenewal := true disableRenewal := true
enableSSHCA := true
p := provisioner.List{ p := provisioner.List{
&provisioner.JWK{ &provisioner.JWK{
Name: "Max", Name: "Max",
@ -29,6 +30,9 @@ func testAuthority(t *testing.T) *Authority {
Name: "step-cli", Name: "step-cli",
Type: "JWK", Type: "JWK",
Key: clijwk, Key: clijwk,
Claims: &provisioner.Claims{
EnableSSHCA: &enableSSHCA,
},
}, },
&provisioner.JWK{ &provisioner.JWK{
Name: "dev", Name: "dev",
@ -46,19 +50,30 @@ func testAuthority(t *testing.T) *Authority {
DisableRenewal: &disableRenewal, DisableRenewal: &disableRenewal,
}, },
}, },
&provisioner.SSHPOP{
Name: "sshpop",
Type: "SSHPOP",
Claims: &provisioner.Claims{
EnableSSHCA: &enableSSHCA,
},
},
} }
c := &Config{ c := &Config{
Address: "127.0.0.1:443", Address: "127.0.0.1:443",
Root: []string{"testdata/certs/root_ca.crt"}, Root: []string{"testdata/certs/root_ca.crt"},
IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateCert: "testdata/certs/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key", IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.ca.smallstep.com"}, SSH: &SSHConfig{
HostKey: "testdata/secrets/ssh_host_ca_key",
UserKey: "testdata/secrets/ssh_user_ca_key",
},
DNSNames: []string{"example.com"},
Password: "pass", Password: "pass",
AuthorityConfig: &AuthConfig{ AuthorityConfig: &AuthConfig{
Provisioners: p, Provisioners: p,
}, },
} }
a, err := New(c) a, err := New(c, opts...)
assert.FatalError(t, err) assert.FatalError(t, err)
return a return a
} }

View file

@ -6,9 +6,10 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
) )
// Claims extends jose.Claims with step attributes. // Claims extends jose.Claims with step attributes.
@ -36,22 +37,19 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
// authorizeToken parses the token and returns the provisioner used to generate // authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be // the token. This method enforces the One-Time use policy (tokens can only be
// used once). // used once).
func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner.Interface, error) { func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
var errContext = map[string]interface{}{"ott": ott}
// Validate payload // Validate payload
token, err := jose.ParseSigned(ott) tok, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrapf(err, "authorizeToken: error parsing token"), return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token")
http.StatusUnauthorized, errContext}
} }
// Get claims w/out verification. We need to look up the provisioner // Get claims w/out verification. We need to look up the provisioner
// key in order to verify the claims and we need the issuer from the claims // key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeToken"), http.StatusUnauthorized, errContext} return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
} }
// TODO: use new persistence layer abstraction. // TODO: use new persistence layer abstraction.
@ -59,29 +57,27 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
// This check is meant as a stopgap solution to the current lack of a persistence layer. // This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"), return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
http.StatusUnauthorized, errContext}
} }
} }
// This method will also validate the audiences for JWK provisioners. // This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(token, &claims.Claims) p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok { if !ok {
return nil, &apiError{ return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
errors.Errorf("authorizeToken: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")), "not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
http.StatusUnauthorized, errContext}
} }
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
if !SkipTokenReuseFromContext(ctx) { if !SkipTokenReuseFromContext(ctx) {
if reuseKey, err := p.GetTokenID(ott); err == nil { if reuseKey, err := p.GetTokenID(token); err == nil {
ok, err := a.db.UseToken(reuseKey, ott) ok, err := a.db.UseToken(reuseKey, token)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeToken: failed when attempting to store token"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.authorizeToken: failed when attempting to store token")
} }
if !ok { if !ok {
return nil, &apiError{errors.Errorf("authorizeToken: token already used"), http.StatusUnauthorized, errContext} return nil, errs.Unauthorized("authority.authorizeToken: token already used")
} }
} }
} }
@ -89,125 +85,158 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
return p, nil return p, nil
} }
// Authorize grabs the method from the context and authorizes a signature // Authorize grabs the method from the context and authorizes the request by
// request by validating the one-time-token. // validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott} var opts = []interface{}{errs.WithKeyVal("token", token)}
switch m := provisioner.MethodFromContext(ctx); m { switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod: case provisioner.SignMethod:
return a.authorizeSign(ctx, ott) signOpts, err := a.authorizeSign(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.RevokeMethod: case provisioner.RevokeMethod:
return nil, a.authorizeRevoke(ctx, ott) return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SignSSHMethod: case provisioner.SSHSignMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
return a.authorizeSSHSign(ctx, ott) _, err := a.authorizeSSHSign(ctx, token)
case provisioner.RenewSSHMethod: return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.SSHRenewMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
if _, err := a.authorizeSSHRenew(ctx, ott); err != nil { _, err := a.authorizeSSHRenew(ctx, token)
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
} case provisioner.SSHRevokeMethod:
return nil, nil return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.RevokeSSHMethod: case provisioner.SSHRekeyMethod:
return nil, a.authorizeSSHRevoke(ctx, ott)
case provisioner.RekeySSHMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
} }
_, opts, err := a.authorizeSSHRekey(ctx, ott) _, signOpts, err := a.authorizeSSHRekey(ctx, token)
if err != nil { return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
return nil, err
}
return opts, nil
default: default:
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext} return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
} }
} }
// authorizeSign loads the provisioner from the token, checks that it has not // authorizeSign loads the provisioner from the token and calls the provisioner
// been used again and calls the provisioner AuthorizeSign method. Returns a // AuthorizeSign method. Returns a list of methods to apply to the signing flow.
// list of methods to apply to the signing flow. func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { p, err := a.authorizeToken(ctx, token)
var errContext = apiCtx{"ott": ott}
p, err := a.authorizeToken(ctx, ott)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
} }
opts, err := p.AuthorizeSign(ctx, ott) signOpts, err := p.AuthorizeSign(ctx, token)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
} }
return opts, nil return signOpts, nil
} }
// AuthorizeSign authorizes a signature request by validating and authenticating // AuthorizeSign authorizes a signature request by validating and authenticating
// a OTT that must be sent w/ the request. // a token that must be sent w/ the request.
// //
// NOTE: This method is deprecated and should not be used. We make it available // NOTE: This method is deprecated and should not be used. We make it available
// in the short term os as not to break existing clients. // in the short term os as not to break existing clients.
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, ott) return a.Authorize(ctx, token)
} }
// authorizeRevoke authorizes a revocation request by validating and authenticating // authorizeRevoke locates the provisioner used to generate the authenticating
// the RevokeOptions POSTed with the request. // token and then performs the token validation flow.
// Returns a tuple of the provisioner ID and error, if one occurred.
func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token) p, err := a.authorizeToken(ctx, token)
if err != nil { if err != nil {
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext} return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
if err = p.AuthorizeRevoke(ctx, token); err != nil { if err = p.AuthorizeRevoke(ctx, token); err != nil {
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext} return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
return nil return nil
} }
// authorizeRenewl tries to locate the step provisioner extension, and checks // authorizeRenew locates the provisioner (using the provisioner extension in the cert), and checks
// if for the configured provisioner, the renewal is enabled or not. If the // if for the configured provisioner, the renewal is enabled or not. If the
// extra extension cannot be found, authorize the renewal by default. // extra extension cannot be found, authorize the renewal by default.
// //
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(crt *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()} var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table. // Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(crt.SerialNumber.String()) isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
if err != nil { if err != nil {
return &apiError{ return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
err: errors.Wrap(err, "renew"),
code: http.StatusInternalServerError,
context: errContext,
}
} }
if isRevoked { if isRevoked {
return &apiError{ return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
err: errors.New("renew: certificate has been revoked"),
code: http.StatusUnauthorized,
context: errContext,
}
} }
p, ok := a.provisioners.LoadByCertificate(crt) p, ok := a.provisioners.LoadByCertificate(cert)
if !ok { if !ok {
return &apiError{ return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
err: errors.New("renew: provisioner not found"),
code: http.StatusUnauthorized,
context: errContext,
}
}
if err := p.AuthorizeRenew(context.Background(), crt); err != nil {
return &apiError{
err: errors.Wrap(err, "renew"),
code: http.StatusUnauthorized,
context: errContext,
} }
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
return nil
}
// authorizeSSHSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
}
signOpts, err := p.AuthorizeSSHSign(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
}
return signOpts, nil
}
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
}
cert, err := p.AuthorizeSSHRenew(ctx, token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
}
return cert, nil
}
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
}
cert, signOpts, err := p.AuthorizeSSHRekey(ctx, token)
if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
}
return cert, signOpts, nil
}
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
p, err := a.authorizeToken(ctx, token)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
}
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
} }
return nil return nil
} }

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"fmt"
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -9,7 +10,6 @@ import (
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
stepJOSE "github.com/smallstep/cli/jose" stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
) )
func TestConfigValidate(t *testing.T) { func TestConfigValidate(t *testing.T) {
@ -255,28 +255,6 @@ func TestAuthConfigValidate(t *testing.T) {
err: errors.New("authority cannot be undefined"), err: errors.New("authority cannot be undefined"),
} }
}, },
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: provisioner.List{
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
},
},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{
ac: &AuthConfig{
Provisioners: p,
Claims: &provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: -1},
},
},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
}
},
"ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest { "ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest {
return AuthConfigValidateTest{ return AuthConfigValidateTest{
ac: &AuthConfig{}, ac: &AuthConfig{},
@ -311,7 +289,7 @@ func TestAuthConfigValidate(t *testing.T) {
assert.Equals(t, tc.err.Error(), err.Error()) assert.Equals(t, tc.err.Error(), err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err, fmt.Sprintf("expected error: %s, but got <nil>", tc.err)) {
assert.Equals(t, *tc.ac.Template, tc.asn1dn) assert.Equals(t, *tc.ac.Template, tc.asn1dn)
} }
} }

View file

@ -1,96 +0,0 @@
package authority
import (
"crypto/x509"
"github.com/smallstep/certificates/db"
"golang.org/x/crypto/ssh"
)
type MockAuthDB struct {
err error
ret1 interface{}
isRevoked func(string) (bool, error)
isSSHRevoked func(string) (bool, error)
revoke func(rci *db.RevokedCertificateInfo) error
revokeSSH func(rci *db.RevokedCertificateInfo) error
storeCertificate func(crt *x509.Certificate) error
useToken func(id, tok string) (bool, error)
isSSHHost func(principal string) (bool, error)
storeSSHCertificate func(crt *ssh.Certificate) error
getSSHHostPrincipals func() ([]string, error)
shutdown func() error
}
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
if m.isRevoked != nil {
return m.isRevoked(sn)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
if m.isSSHRevoked != nil {
return m.isSSHRevoked(sn)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
if m.useToken != nil {
return m.useToken(id, tok)
}
if m.ret1 == nil {
return false, m.err
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error {
if m.revoke != nil {
return m.revoke(rci)
}
return m.err
}
func (m *MockAuthDB) RevokeSSH(rci *db.RevokedCertificateInfo) error {
if m.revokeSSH != nil {
return m.revokeSSH(rci)
}
return m.err
}
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.storeCertificate != nil {
return m.storeCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
if m.isSSHHost != nil {
return m.isSSHHost(principal)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
if m.storeSSHCertificate != nil {
return m.storeSSHCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
if m.getSSHHostPrincipals != nil {
return m.getSSHHostPrincipals()
}
return m.ret1.([]string), m.err
}
func (m *MockAuthDB) Shutdown() error {
if m.shutdown != nil {
return m.shutdown()
}
return m.err
}

View file

@ -1,67 +0,0 @@
package authority
import (
"encoding/json"
"fmt"
"net/http"
)
type apiCtx map[string]interface{}
// Error implements the api.Error interface and adds context to error messages.
type apiError struct {
err error
code int
context apiCtx
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *apiError) Cause() error {
return e.err
}
// Error returns an error message with additional context.
func (e *apiError) Error() string {
ret := e.err.Error()
/*
if len(e.context) > 0 {
ret += "\n\nContext:"
for k, v := range e.context {
ret += fmt.Sprintf("\n %s: %v", k, v)
}
}
*/
return ret
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// StatusCode returns an http status code indicating the type and severity of
// the error.
func (e *apiError) StatusCode() int {
if e.code == 0 {
return http.StatusInternalServerError
}
return e.code
}
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *apiError) MarshalJSON() ([]byte, error) {
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *apiError) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.code = er.Status
e.err = fmt.Errorf(er.Message)
return nil
}

View file

@ -5,6 +5,7 @@ import (
"crypto/x509" "crypto/x509"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
) )
// ACME is the acme provisioner type, an entity that can authorize the ACME // ACME is the acme provisioner type, an entity that can authorize the ACME
@ -79,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
} }
return nil return nil
} }

View file

@ -3,11 +3,13 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
) )
func TestACME_Getters(t *testing.T) { func TestACME_Getters(t *testing.T) {
@ -88,86 +90,98 @@ func TestACME_Init(t *testing.T) {
} }
} }
func TestACME_AuthorizeRevoke(t *testing.T) { func TestACME_AuthorizeRenew(t *testing.T) {
type test struct {
p *ACME
cert *x509.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateACME() p, err := generateACME()
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Nil(t, p.AuthorizeRevoke(context.TODO(), ""))
}
func TestACME_AuthorizeRenew(t *testing.T) {
p1, err := generateACME()
assert.FatalError(t, err)
p2, err := generateACME()
assert.FatalError(t, err)
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{
type args struct { p: p,
cert *x509.Certificate cert: &x509.Certificate{},
code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()),
} }
tests := []struct { },
name string "ok": func(t *testing.T) test {
prov *ACME p, err := generateACME()
args args assert.FatalError(t, err)
err error return test{
}{ p: p,
{"ok", p1, args{nil}, nil}, cert: &x509.Certificate{},
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
} }
for _, tt := range tests { },
t.Run(tt.name, func(t *testing.T) { }
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); err != nil { for name, tt := range tests {
if assert.NotNil(t, tt.err) { t.Run(name, func(t *testing.T) {
assert.HasPrefix(t, err.Error(), tt.err.Error()) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
assert.Nil(t, tt.err) assert.Nil(t, tc.err)
} }
}) })
} }
} }
func TestACME_AuthorizeSign(t *testing.T) { func TestACME_AuthorizeSign(t *testing.T) {
p1, err := generateACME() type test struct {
assert.FatalError(t, err) p *ACME
token string
tests := []struct { code int
name string
prov *ACME
method Method
err error err error
}{
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
{"ok", p1, SignMethod, nil},
} }
for _, tt := range tests { tests := map[string]func(*testing.T) test{
t.Run(tt.name, func(t *testing.T) { "ok": func(t *testing.T) test {
ctx := NewContextWithMethod(context.Background(), tt.method) p, err := generateACME()
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil { assert.FatalError(t, err)
if assert.NotNil(t, tt.err) { return test{
assert.HasPrefix(t, err.Error(), tt.err.Error()) p: p,
token: "foo",
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.NotNil(t, got) { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 4, got) assert.Len(t, 4, opts)
for _, o := range opts {
for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME)) assert.Equals(t, v.Type, int(TypeACME))
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
} }

View file

@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -271,7 +272,7 @@ func (p *AWS) Init(config Config) (err error) {
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token) payload, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign")
} }
doc := payload.document doc := payload.document
@ -305,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -349,41 +350,41 @@ func (p *AWS) readURL(url string) ([]byte, error) {
func (p *AWS) authorizeToken(token string) (*awsPayload, error) { func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing") return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
} }
var unsafeClaims awsPayload var unsafeClaims awsPayload
if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil { if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
return nil, errors.Wrap(err, "error unmarshaling claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling claims")
} }
var payload awsPayload var payload awsPayload
if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil { if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil {
return nil, errors.Wrap(err, "error verifying claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error verifying claims")
} }
// Validate identity document signature // Validate identity document signature
if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil { if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil {
return nil, err return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token signature")
} }
var doc awsInstanceIdentityDocument var doc awsInstanceIdentityDocument
if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil { if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil {
return nil, errors.Wrap(err, "error unmarshaling identity document") return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling aws identity document")
} }
switch { switch {
case doc.AccountID == "": case doc.AccountID == "":
return nil, errors.New("identity document accountId cannot be empty") return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
case doc.InstanceID == "": case doc.InstanceID == "":
return nil, errors.New("identity document instanceId cannot be empty") return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
case doc.PrivateIP == "": case doc.PrivateIP == "":
return nil, errors.New("identity document privateIp cannot be empty") return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
case doc.Region == "": case doc.Region == "":
return nil, errors.New("identity document region cannot be empty") return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -393,12 +394,12 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
Issuer: awsIssuer, Issuer: awsIssuer,
Time: now, Time: now,
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token")
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) { if !matchesAudience(payload.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
} }
// Validate subject, it has to be known if disableCustomSANs is enabled // Validate subject, it has to be known if disableCustomSANs is enabled
@ -406,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if payload.Subject != doc.InstanceID && if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP && payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
return nil, errors.New("invalid token: invalid subject claim (sub)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
} }
} }
@ -420,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
} }
if !found { if !found {
return nil, errors.New("invalid identity document: accountId is not valid") return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
} }
} }
// validate instance age // validate instance age
if d := p.InstanceAge.Value(); d > 0 { if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(doc.PendingTime) > d { if now.Sub(doc.PendingTime) > d {
return nil, errors.New("identity document pendingTime is too old") return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
} }
} }
@ -438,18 +439,18 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign")
} }
doc := claims.document doc := claims.document
signOptions := []SignOption{ signOptions := []SignOption{
// set the key id to the token subject // set the key id to the token subject
sshCertificateKeyIDModifier(claims.Subject), sshCertKeyIDModifier(claims.Subject),
} }
// Default to host + known IPs/hostnames // Default to host + known IPs/hostnames
@ -461,9 +462,9 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}, },
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions, return append(signOptions,
// Set the default extensions. // Set the default extensions.
@ -473,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

View file

@ -10,12 +10,15 @@ import (
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -229,6 +232,213 @@ func TestAWS_Init(t *testing.T) {
} }
} }
func TestAWS_authorizeToken(t *testing.T) {
block, _ := pem.Decode([]byte(awsTestKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
t.Fatal("error decoding AWS key")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.FatalError(t, err)
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
type test struct {
p *AWS
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; error parsing aws token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), badKey)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws token signature"),
}
},
"fail/empty-account-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), "", "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"),
}
},
"fail/empty-instance-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"),
}
},
"fail/empty-private-ip": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"),
}
},
"fail/empty-region": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document region cannot be empty"),
}
},
"fail/invalid-token-issuer": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", "bad-issuer", p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws token"),
}
},
"fail/invalid-audience": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, "bad-audience", p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"),
}
},
"fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
p.DisableCustomSANs = true
tok, err := generateAWSToken(
"foo", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"),
}
},
"fail/invalid-account-id": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), "foo", "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"),
}
},
"fail/instance-age": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
p.InstanceAge = Duration{1 * time.Minute}
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("aws.authorizeToken; aws identity document pendingTime is too old"),
}
},
"ok": func(t *testing.T) test {
p, err := generateAWS()
assert.FatalError(t, err)
tok, err := generateAWSToken(
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
"127.0.0.1", "us-west-1", time.Now(), key)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
assert.Equals(t, claims.Subject, "instance-id")
assert.Equals(t, claims.Issuer, awsIssuer)
assert.NotNil(t, claims.Amazon)
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
assert.FatalError(t, err)
assert.Equals(t, claims.Audience[0], aud)
}
}
})
}
}
func TestAWS_AuthorizeSign(t *testing.T) { func TestAWS_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAWSWithServer() p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -326,26 +536,27 @@ func TestAWS_AuthorizeSign(t *testing.T) {
aws *AWS aws *AWS
args args args args
wantLen int wantLen int
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 5, false}, {"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 7, false}, {"ok", p2, args{t2}, 7, http.StatusOK, false},
{"ok", p2, args{t2Hostname}, 7, false}, {"ok", p2, args{t2Hostname}, 7, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP}, 7, false}, {"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false},
{"ok", p1, args{t4}, 5, false}, {"ok", p1, args{t4}, 5, http.StatusOK, false},
{"fail account", p3, args{t3}, 0, true}, {"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{failSubject}, 0, true}, {"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, true}, {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, true}, {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
{"fail account", p1, args{failAccount}, 0, true}, {"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true},
{"fail instanceID", p1, args{failInstanceID}, 0, true}, {"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail privateIP", p1, args{failPrivateIP}, 0, true}, {"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true},
{"fail region", p1, args{failRegion}, 0, true}, {"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, true}, {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, true}, {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail instance age", p2, args{failInstanceAge}, 0, true}, {"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -354,8 +565,13 @@ func TestAWS_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} } else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
}
}) })
} }
} }
@ -368,6 +584,14 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
p2, err := generateAWS()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com") t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
assert.FatalError(t, err) assert.FatalError(t, err)
@ -407,30 +631,35 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
aws *AWS aws *AWS
args args args args
expected *SSHOptions expected *SSHOptions
code int
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false}, {"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false},
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false}, {"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true}, {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token)
got, err := tt.aws.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -447,6 +676,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
}) })
} }
} }
func TestAWS_AuthorizeRenew(t *testing.T) { func TestAWS_AuthorizeRenew(t *testing.T) {
p1, err := generateAWS() p1, err := generateAWS()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -466,44 +696,20 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
name string name string
aws *AWS aws *AWS
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, false}, {"ok", p1, args{nil}, http.StatusOK, false},
{"fail", p2, args{nil}, true}, {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} } else if err != nil {
}) sc, ok := err.(errs.StatusCoder)
} assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} assert.Equals(t, sc.StatusCode(), tt.code)
func TestAWS_AuthorizeRevoke(t *testing.T) {
p1, srv, err := generateAWSWithServer()
assert.FatalError(t, err)
defer srv.Close()
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
aws *AWS
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

View file

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -209,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
return nil return nil
} }
// parseToken returns the claims, name, group, error. // authorizeToken returns the claims, name, group, error.
func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) { func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, "", "", errors.Wrapf(err, "error parsing token") return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, "", "", errors.New("error parsing token: header is missing") return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
} }
var found bool var found bool
@ -229,7 +230,7 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
} }
} }
if !found { if !found {
return nil, "", "", errors.New("cannot validate token") return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
} }
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
@ -237,17 +238,17 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
Issuer: p.oidcConfig.Issuer, Issuer: p.oidcConfig.Issuer,
Time: time.Now(), Time: time.Now(),
}, 1*time.Minute); err != nil { }, 1*time.Minute); err != nil {
return nil, "", "", errors.Wrap(err, "failed to validate payload") return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
} }
// Validate TenantID // Validate TenantID
if claims.TenantID != p.TenantID { if claims.TenantID != p.TenantID {
return nil, "", "", errors.New("validation failed: invalid tenant id claim (tid)") return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 4 {
return nil, "", "", errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID) return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
group, name := re[2], re[3] group, name := re[2], re[3]
return &claims, name, group, nil return &claims, name, group, nil
@ -256,9 +257,9 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation. // will be used on certificate creation.
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, err := p.parseToken(token) _, name, group, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
} }
// Filter by resource group // Filter by resource group
@ -271,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
} }
if !found { if !found {
return nil, errors.New("validation failed: invalid resource group") return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
} }
} }
@ -301,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -309,16 +310,16 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
} }
_, name, _, err := p.parseToken(token) _, name, _, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
} }
signOptions := []SignOption{ signOptions := []SignOption{
// set the key id to the token subject // set the key id to the token subject
sshCertificateKeyIDModifier(name), sshCertKeyIDModifier(name),
} }
// Default to host + known hostnames // Default to host + known hostnames
@ -327,9 +328,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
Principals: []string{name}, Principals: []string{name},
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions, return append(signOptions,
// Set the default extensions. // Set the default extensions.
@ -339,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

View file

@ -15,7 +15,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
) )
func TestAzure_Getters(t *testing.T) { func TestAzure_Getters(t *testing.T) {
@ -209,6 +212,148 @@ func TestAzure_Init(t *testing.T) {
} }
} }
func TestAzure_authorizeToken(t *testing.T) {
type test struct {
p *Azure
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateAzure()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; error parsing azure token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; cannot validate azure token"),
}
},
"fail/invalid-token-issuer": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; failed to validate azure token payload"),
}
},
"fail/invalid-tenant-id": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"),
}
},
"fail/invalid-xms-mir-id": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
jwk := &p.keyStore.keySet.Keys[0]
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
)
assert.FatalError(t, err)
now := time.Now()
claims := azurePayload{
Claims: jose.Claims{
Subject: "subject",
Issuer: p.oidcConfig.Issuer,
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{azureDefaultAudience},
ID: "the-jti",
},
AppID: "the-appid",
AppIDAcr: "the-appidacr",
IdentityProvider: "the-idp",
ObjectID: "the-oid",
TenantID: p.TenantID,
Version: "the-version",
XMSMirID: "foo",
}
tok, err := jose.Signed(sig).Claims(claims).CompactSerialize()
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("azure.authorizeToken; error parsing xms_mirid claim - foo"),
}
},
"ok": func(t *testing.T) test {
p, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, claims.Subject, "subject")
assert.Equals(t, claims.Issuer, tc.p.oidcConfig.Issuer)
assert.Equals(t, claims.Audience[0], azureDefaultAudience)
assert.Equals(t, name, "virtualMachine")
assert.Equals(t, group, "resourceGroup")
}
}
})
}
}
func TestAzure_AuthorizeSign(t *testing.T) { func TestAzure_AuthorizeSign(t *testing.T) {
p1, srv, err := generateAzureWithServer() p1, srv, err := generateAzureWithServer()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -283,19 +428,20 @@ func TestAzure_AuthorizeSign(t *testing.T) {
azure *Azure azure *Azure
args args args args
wantLen int wantLen int
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 4, false}, {"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, false}, {"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p1, args{t11}, 4, false}, {"ok", p1, args{t11}, 4, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, true}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, true}, {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, true}, {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, true}, {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, true}, {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -304,8 +450,51 @@ func TestAzure_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} } else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
}
})
}
}
func TestAzure_AuthorizeRenew(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
}) })
} }
} }
@ -318,6 +507,14 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
p2, err := generateAzure()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL") t1, err := p1.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err) assert.FatalError(t, err)
@ -349,28 +546,33 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
azure *Azure azure *Azure
args args args args
expected *SSHOptions expected *SSHOptions
code int
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true}, {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) got, err := tt.azure.AuthorizeSSHSign(context.Background(), tt.args.token)
got, err := tt.azure.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -388,68 +590,6 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
} }
} }
func TestAzure_AuthorizeRenew(t *testing.T) {
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_AuthorizeRevoke(t *testing.T) {
az, srv, err := generateAzureWithServer()
assert.FatalError(t, err)
defer srv.Close()
token, err := az.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
azure *Azure
args args
wantErr bool
}{
{"ok token", az, args{token}, true}, // revoke is disabled
{"bad token", az, args{"bad token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestAzure_assertConfig(t *testing.T) { func TestAzure_assertConfig(t *testing.T) {
p1, err := generateAzure() p1, err := generateAzure()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -78,7 +78,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// match with server audiences // match with server audiences
if matchesAudience(claims.Audience, audiences) { if matchesAudience(claims.Audience, audiences) {
// Use fragment to get provisioner name (GCP, AWS) // Use fragment to get provisioner name (GCP, AWS, SSHPOP)
if fragment != "" { if fragment != "" {
return c.Load(fragment) return c.Load(fragment)
} }

View file

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -210,7 +211,7 @@ func (p *GCP) Init(config Config) error {
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign")
} }
ce := claims.Google.ComputeEngine ce := claims.Google.ComputeEngine
@ -239,10 +240,10 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
), nil ), nil
} }
// AuthorizeRenewal returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenewal(ctx context.Context, cert *x509.Certificate) error { func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -260,10 +261,10 @@ func (p *GCP) assertConfig() {
func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing") return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
} }
var found bool var found bool
@ -277,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errors.Errorf("failed to validate payload: cannot find key for kid %s", kid) return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -287,12 +288,12 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
Issuer: "https://accounts.google.com", Issuer: "https://accounts.google.com",
Time: now, Time: now,
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; invalid gcp token payload")
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) { if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errors.New("invalid token: invalid audience claim (aud)") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
} }
// validate subject (service account) // validate subject (service account)
@ -305,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errors.New("invalid token: invalid subject claim") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
} }
} }
@ -319,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
} }
if !found { if !found {
return nil, errors.New("invalid token: invalid project id") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
} }
} }
// validate instance age // validate instance age
if d := p.InstanceAge.Value(); d > 0 { if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d { if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
return nil, errors.New("token google.compute_engine.instance_creation_timestamp is too old") return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
} }
} }
switch { switch {
case claims.Google.ComputeEngine.InstanceID == "": case claims.Google.ComputeEngine.InstanceID == "":
return nil, errors.New("token google.compute_engine.instance_id cannot be empty") return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
case claims.Google.ComputeEngine.InstanceName == "": case claims.Google.ComputeEngine.InstanceName == "":
return nil, errors.New("token google.compute_engine.instance_name cannot be empty") return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
case claims.Google.ComputeEngine.ProjectID == "": case claims.Google.ComputeEngine.ProjectID == "":
return nil, errors.New("token google.compute_engine.project_id cannot be empty") return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
case claims.Google.ComputeEngine.Zone == "": case claims.Google.ComputeEngine.Zone == "":
return nil, errors.New("token google.compute_engine.zone cannot be empty") return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
} }
return &claims, nil return &claims, nil
@ -347,18 +348,18 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign")
} }
ce := claims.Google.ComputeEngine ce := claims.Google.ComputeEngine
signOptions := []SignOption{ signOptions := []SignOption{
// set the key id to the token subject // set the key id to the token subject
sshCertificateKeyIDModifier(ce.InstanceName), sshCertKeyIDModifier(ce.InstanceName),
} }
// Default to host + known hostnames // Default to host + known hostnames
@ -370,9 +371,9 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}, },
} }
// Validate user options // Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options // Set defaults if not given as user options
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions, return append(signOptions,
// Set the default extensions // Set the default extensions
@ -382,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

View file

@ -16,7 +16,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
) )
func TestGCP_Getters(t *testing.T) { func TestGCP_Getters(t *testing.T) {
@ -211,6 +214,202 @@ func TestGCP_Init(t *testing.T) {
} }
} }
func TestGCP_authorizeToken(t *testing.T) {
type test struct {
p *GCP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; error parsing gcp token"),
}
},
"fail/cannot-validate-sig": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid "),
}
},
"fail/invalid-issuer": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://foo.bar.zap", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token payload"),
}
},
"fail/invalid-serviceAccount": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken("foo",
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"),
}
},
"fail/invalid-projectID": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
p.ProjectIDs = []string{"foo", "bar"}
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"),
}
},
"fail/instance-age": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
p.InstanceAge = Duration{1 * time.Minute}
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now().Add(-1*time.Minute), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"),
}
},
"fail/empty-instance-id": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"),
}
},
"fail/empty-instance-name": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"),
}
},
"fail/empty-project-id": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"),
}
},
"fail/empty-zone": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"),
}
},
"ok": func(t *testing.T) test {
p, err := generateGCP()
assert.FatalError(t, err)
tok, err := generateGCPToken(p.ServiceAccounts[0],
"https://accounts.google.com", p.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
assert.Equals(t, claims.Subject, tc.p.ServiceAccounts[0])
assert.Equals(t, claims.Issuer, "https://accounts.google.com")
assert.NotNil(t, claims.Google)
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
assert.FatalError(t, err)
assert.Equals(t, claims.Audience[0], aud)
}
}
})
}
}
func TestGCP_AuthorizeSign(t *testing.T) { func TestGCP_AuthorizeSign(t *testing.T) {
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -313,24 +512,25 @@ func TestGCP_AuthorizeSign(t *testing.T) {
gcp *GCP gcp *GCP
args args args args
wantLen int wantLen int
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 4, false}, {"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, false}, {"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p3, args{t3}, 4, false}, {"ok", p3, args{t3}, 4, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
{"fail aud", p1, args{failAud}, 0, true}, {"fail aud", p1, args{failAud}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, true}, {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, true}, {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail service account", p1, args{failServiceAccount}, 0, true}, {"fail service account", p1, args{failServiceAccount}, 0, http.StatusUnauthorized, true},
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, true}, {"fail invalid project id", p3, args{failInvalidProjectID}, 0, http.StatusUnauthorized, true},
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, true}, {"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, http.StatusUnauthorized, true},
{"fail instance id", p1, args{failInstanceID}, 0, true}, {"fail instance id", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail instance name", p1, args{failInstanceName}, 0, true}, {"fail instance name", p1, args{failInstanceName}, 0, http.StatusUnauthorized, true},
{"fail project id", p1, args{failProjectID}, 0, true}, {"fail project id", p1, args{failProjectID}, 0, http.StatusUnauthorized, true},
{"fail zone", p1, args{failZone}, 0, true}, {"fail zone", p1, args{failZone}, 0, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -339,8 +539,13 @@ func TestGCP_AuthorizeSign(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} } else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
}
}) })
} }
} }
@ -352,6 +557,14 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateGCP()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0], t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(), "https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone", "instance-id", "instance-name", "project-id", "zone",
@ -394,30 +607,35 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
gcp *GCP gcp *GCP
args args args args
expected *SSHOptions expected *SSHOptions
code int
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false}, {"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, http.StatusOK, false, false},
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false}, {"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true}, {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) got, err := tt.gcp.AuthorizeSSHSign(context.Background(), tt.args.token)
got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -435,7 +653,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
} }
} }
func TestGCP_AuthorizeRenewal(t *testing.T) { func TestGCP_AuthorizeRenew(t *testing.T) {
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateGCP() p2, err := generateGCP()
@ -454,46 +672,20 @@ func TestGCP_AuthorizeRenewal(t *testing.T) {
name string name string
prov *GCP prov *GCP
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, false}, {"ok", p1, args{nil}, http.StatusOK, false},
{"fail", p2, args{nil}, true}, {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} } else if err != nil {
}) sc, ok := err.(errs.StatusCoder)
} assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} assert.Equals(t, sc.StatusCode(), tt.code)
func TestGCP_AuthorizeRevoke(t *testing.T) {
p1, err := generateGCP()
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
"https://accounts.google.com", p1.GetID(),
"instance-id", "instance-name", "project-id", "zone",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
gcp *GCP
args args
wantErr bool
}{
{"ok", p1, args{t1}, true}, // revoke is disabled
{"fail", p1, args{"token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.gcp.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

View file

@ -3,9 +3,11 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"net/http"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -99,12 +101,12 @@ func (p *JWK) Init(config Config) (err error) {
func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) { func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk token")
} }
var claims jwtPayload var claims jwtPayload
if err = jwt.Claims(p.Key, &claims); err != nil { if err = jwt.Claims(p.Key, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk claims")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -113,17 +115,17 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
Issuer: p.Name, Issuer: p.Name,
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errs.Wrapf(http.StatusUnauthorized, err, "jwk.authorizeToken; invalid jwk claims")
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errors.Errorf("invalid token: invalid audience claim (aud); want %s, but got %s", return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
audiences, claims.Audience) audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty") return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
} }
return &claims, nil return &claims, nil
@ -133,14 +135,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.audiences.Revoke)
return err return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
} }
// NOTE: This is for backwards compatibility with older versions of cli // NOTE: This is for backwards compatibility with older versions of cli
@ -171,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -179,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
} }
if claims.Step == nil || claims.Step.SSH == nil { if claims.Step == nil || claims.Step.SSH == nil {
return nil, errors.New("authorization token must be an SSH provisioning token") return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
} }
opts := claims.Step.SSH opts := claims.Step.SSH
signOptions := []SignOption{ signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token // validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts), sshCertOptionsValidator(*opts),
} }
t := now() t := now()
@ -205,19 +207,19 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals)) signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals))
} }
if !opts.ValidAfter.IsZero() { if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
} }
if !opts.ValidBefore.IsZero() { if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
} }
if opts.KeyID != "" { if opts.KeyID != "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(opts.KeyID)) signOptions = append(signOptions, sshCertKeyIDModifier(opts.KeyID))
} else { } else {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject)) signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
} }
// Default to a user certificate with no principals if not set // Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
return append(signOptions, return append(signOptions,
// Set the default extensions. // Set the default extensions.
@ -229,14 +231,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke) _, err := p.authorizeToken(token, p.audiences.SSHRevoke)
return err return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
} }

View file

@ -7,12 +7,14 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"net" "net"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -162,25 +164,29 @@ func TestJWK_authorizeToken(t *testing.T) {
name string name string
prov *JWK prov *JWK
args args args args
code int
err error err error
}{ }{
{"fail-token", p1, args{failTok}, errors.New("error parsing token")}, {"fail-token", p1, args{failTok}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk token")},
{"fail-key", p1, args{failKey}, errors.New("error parsing claims")}, {"fail-key", p1, args{failKey}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
{"fail-claims", p1, args{failClaims}, errors.New("error parsing claims")}, {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"fail-issuer", p1, args{failIss}, errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")}, {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
{"fail-expired", p1, args{failExp}, errors.New("invalid token: square/go-jose/jwt: validation failed, token is expired (exp)")}, {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token is expired (exp)")},
{"fail-not-before", p1, args{failNbf}, errors.New("invalid token: square/go-jose/jwt: validation failed, token not valid yet (nbf)")}, {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
{"fail-audience", p1, args{failAud}, errors.New("invalid token: invalid audience claim (aud)")}, {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk token audience claim (aud)")},
{"fail-subject", p1, args{failSub}, errors.New("token subject cannot be empty")}, {"fail-subject", p1, args{failSub}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; jwk token subject cannot be empty")},
{"ok", p1, args{t1}, nil}, {"ok", p1, args{t1}, http.StatusOK, nil},
{"ok-no-encrypted-key", p2, args{t2}, nil}, {"ok-no-encrypted-key", p2, args{t2}, http.StatusOK, nil},
{"ok-no-sans", p1, args{t3}, nil}, {"ok-no-sans", p1, args{t3}, http.StatusOK, nil},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
} else { } else {
@ -208,15 +214,19 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
name string name string
prov *JWK prov *JWK
args args args args
code int
err error err error
}{ }{
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.AuthorizeRevoke: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"ok", p1, args{t1}, nil}, {"ok", p1, args{t1}, http.StatusOK, nil},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
} }
@ -246,20 +256,24 @@ func TestJWK_AuthorizeSign(t *testing.T) {
name string name string
prov *JWK prov *JWK
args args args args
code int
err error err error
dns []string dns []string
emails []string emails []string
ips []net.IP ips []net.IP
}{ }{
{name: "fail-signature", prov: p1, args: args{failSig}, err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, {name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"ok-sans", p1, args{t1}, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}}, {"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
{"ok-no-sans", p1, args{t2}, nil, []string{"subject"}, []string{}, []net.IP{}}, {"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
} else { } else {
@ -315,15 +329,20 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
name string name string
prov *JWK prov *JWK
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, false}, {"ok", p1, args{nil}, http.StatusOK, false},
{"fail", p2, args{nil}, true}, {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
} }
@ -335,6 +354,14 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -382,30 +409,34 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
prov *JWK prov *JWK
args args args args
expected *SSHOptions expected *SSHOptions
code int
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false}, {"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false}, {"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false}, {"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, {"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, {"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false}, {"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, {"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, {"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, {"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false}, {"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false},
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true}, {"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -511,10 +542,9 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk) token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
if got, err := tt.prov.AuthorizeSSHSign(ctx, token); (err != nil) != tt.wantErr { if got, err := tt.prov.AuthorizeSSHSign(context.Background(), token); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
} else if !tt.wantErr && assert.NotNil(t, got) { } else if !tt.wantErr && assert.NotNil(t, got) {
var opts SSHOptions var opts SSHOptions
@ -535,3 +565,52 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
}) })
} }
} }
func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
type test struct {
p *JWK
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
p, err := generateJWK()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("jwk.AuthorizeSSHRevoke: jwk.authorizeToken; error parsing jwk token"),
}
},
"ok": func(t *testing.T) test {
p, err := generateJWK()
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p.EncryptedKey)
assert.FatalError(t, err)
tok, err := generateToken("subject", p.Name, testAudiences.SSHRevoke[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

View file

@ -6,8 +6,10 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -138,7 +140,8 @@ func (p *K8sSA) Init(config Config) (err error) {
func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) { func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrap(http.StatusUnauthorized, err,
"k8ssa.authorizeToken; error parsing k8sSA token")
} }
var ( var (
@ -146,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
claims k8sSAPayload claims k8sSAPayload
) )
if p.pubKeys == nil { if p.pubKeys == nil {
return nil, errors.New("TokenReview API integration not implemented") return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
/* NOTE: We plan to support the TokenReview API in a future release. /* NOTE: We plan to support the TokenReview API in a future release.
Below is some code that should be useful when we prioritize Below is some code that should be useful when we prioritize
this integration. this integration.
@ -174,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
} }
} }
if !valid { if !valid {
return nil, errors.New("error validating token and extracting claims") return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -182,11 +185,11 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
if err = claims.Validate(jose.Expected{ if err = claims.Validate(jose.Expected{
Issuer: k8sSAIssuer, Issuer: k8sSAIssuer,
}); err != nil { }); err != nil {
return nil, errors.Wrapf(err, "invalid token claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; invalid k8sSA token claims")
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty") return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
} }
return &claims, nil return &claims, nil
@ -196,14 +199,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.audiences.Revoke)
return err return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, err := p.authorizeToken(token, p.audiences.Sign) if _, err := p.authorizeToken(token, p.audiences.Sign); err != nil {
if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
return nil, err
} }
return []SignOption{ return []SignOption{
@ -219,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -227,17 +229,14 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("authorizeSSHSign: ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
} }
_, err := p.authorizeToken(token, p.audiences.SSHSign) if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
return nil, errors.Wrap(err, "authorizeSSHSign")
} }
// Default to a user certificate with no principals if not set // Default to a user certificate with no principals if not set
signOptions := []SignOption{ signOptions := []SignOption{sshCertDefaultsModifier{CertType: SSHUserCert}}
sshCertificateDefaultsModifier{CertType: SSHUserCert},
}
return append(signOptions, return append(signOptions,
// Set the default extensions. // Set the default extensions.
@ -247,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

View file

@ -3,11 +3,13 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -36,6 +38,7 @@ func TestK8sSA_authorizeToken(t *testing.T) {
p *K8sSA p *K8sSA
token string token string
err error err error
code int
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test { "fail/bad-token": func(t *testing.T) test {
@ -44,7 +47,24 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
err: errors.New("error parsing token"), code: http.StatusUnauthorized,
err: errors.New("k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"fail/not-implemented": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk)
p.pubKeys = nil
assert.FatalError(t, err)
return test{
p: p,
token: tok,
err: errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"),
code: http.StatusUnauthorized,
} }
}, },
"fail/error-validating-token": func(t *testing.T) test { "fail/error-validating-token": func(t *testing.T) test {
@ -58,7 +78,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("error validating token and extracting claims"), err: errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"),
code: http.StatusUnauthorized,
} }
}, },
"fail/invalid-issuer": func(t *testing.T) test { "fail/invalid-issuer": func(t *testing.T) test {
@ -73,7 +94,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("invalid token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"), code: http.StatusUnauthorized,
err: errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -94,6 +116,9 @@ func TestK8sSA_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
@ -105,12 +130,12 @@ func TestK8sSA_authorizeToken(t *testing.T) {
} }
} }
func TestK8sSA_AuthorizeSign(t *testing.T) { func TestK8sSA_AuthorizeRevoke(t *testing.T) {
type test struct { type test struct {
p *K8sSA p *K8sSA
token string token string
ctx context.Context
err error err error
code int
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
@ -119,21 +144,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
err: errors.New("error parsing token"), code: http.StatusUnauthorized,
} err: errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"),
},
"fail/ssh-unimplemented": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(jwk.Public().Key)
assert.FatalError(t, err)
tok, err := generateK8sSAToken(jwk, nil)
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.Errorf("ssh certificates not enabled for k8s ServiceAccount provisioners"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -145,7 +157,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok, token: tok,
} }
}, },
@ -153,10 +164,110 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestK8sSA_AuthorizeRenew(t *testing.T) {
type test struct {
p *K8sSA
cert *x509.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()),
}
},
"ok": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestK8sSA_AuthorizeSign(t *testing.T) {
type test struct {
p *K8sSA
token string
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"),
}
},
"ok": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
p, err := generateK8sSA(jwk.Public().Key)
assert.FatalError(t, err)
tok, err := generateK8sSAToken(jwk, nil)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) { if assert.NotNil(t, opts) {
@ -187,20 +298,37 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
} }
} }
func TestK8sSA_AuthorizeRevoke(t *testing.T) { func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
type test struct { type test struct {
p *K8sSA p *K8sSA
token string token string
code int
err error err error
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
"fail/sshCA-disabled": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
// disable sshCA
disable := false
p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()),
}
},
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
p, err := generateK8sSA(nil) p, err := generateK8sSA(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
err: errors.New("error parsing token"), code: http.StatusUnauthorized,
err: errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -219,45 +347,36 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
assert.Nil(t, tc.err) if assert.Nil(t, tc.err) {
} if assert.NotNil(t, opts) {
}) tot := 0
} for _, o := range opts {
} switch v := o.(type) {
case sshCertDefaultsModifier:
func TestK8sSA_AuthorizeRenew(t *testing.T) { assert.Equals(t, v.CertType, SSHUserCert)
p1, err := generateK8sSA(nil) case *sshDefaultExtensionModifier:
assert.FatalError(t, err) case *sshCertValidityValidator:
p2, err := generateK8sSA(nil) assert.Equals(t, v.Claimer, tc.p.claimer)
assert.FatalError(t, err) case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
// disable renewal case *sshDefaultDuration:
disable := true assert.Equals(t, v.Claimer, tc.p.claimer)
p2.Claims = &Claims{DisableRenewal: &disable} default:
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, err) }
tot++
type args struct { }
cert *x509.Certificate assert.Equals(t, tot, 6)
} }
tests := []struct { }
name string
prov *K8sSA
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

View file

@ -16,14 +16,16 @@ const (
SignMethod Method = iota SignMethod Method = iota
// RevokeMethod is the method used to revoke X.509 certificates. // RevokeMethod is the method used to revoke X.509 certificates.
RevokeMethod RevokeMethod
// SignSSHMethod is the method used to sign SSH certificates. // RenewMethod is the method used to renew X.509 certificates.
SignSSHMethod RenewMethod
// RenewSSHMethod is the method used to renew SSH certificates. // SSHSignMethod is the method used to sign SSH certificates.
RenewSSHMethod SSHSignMethod
// RevokeSSHMethod is the method used to revoke SSH certificates. // SSHRenewMethod is the method used to renew SSH certificates.
RevokeSSHMethod SSHRenewMethod
// RekeySSHMethod is the method used to rekey SSH certificates. // SSHRevokeMethod is the method used to revoke SSH certificates.
RekeySSHMethod SSHRevokeMethod
// SSHRekeyMethod is the method used to rekey SSH certificates.
SSHRekeyMethod
) )
// String returns a string representation of the context method. // String returns a string representation of the context method.
@ -33,14 +35,16 @@ func (m Method) String() string {
return "sign-method" return "sign-method"
case RevokeMethod: case RevokeMethod:
return "revoke-method" return "revoke-method"
case SignSSHMethod: case RenewMethod:
return "sign-ssh-method" return "renew-method"
case RenewSSHMethod: case SSHSignMethod:
return "renew-ssh-method" return "ssh-sign-method"
case RevokeSSHMethod: case SSHRenewMethod:
return "revoke-ssh-method" return "ssh-renew-method"
case RekeySSHMethod: case SSHRevokeMethod:
return "rekey-ssh-method" return "ssh-revoke-method"
case SSHRekeyMethod:
return "ssh-rekey-method"
default: default:
return "unknown" return "unknown"
} }

View file

@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "noop", p.GetName()) assert.Equals(t, "noop", p.GetName())
assert.Equals(t, noopType, p.GetType()) assert.Equals(t, noopType, p.GetType())
assert.Equals(t, nil, p.Init(Config{})) assert.Equals(t, nil, p.Init(Config{}))
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{})) assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo")) assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
kid, key, ok := p.GetEncryptedKey() kid, key, ok := p.GetEncryptedKey()
assert.Equals(t, "", kid) assert.Equals(t, "", kid)

View file

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -189,17 +190,17 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
Audience: jose.Audience{o.ClientID}, Audience: jose.Audience{o.ClientID},
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return errors.Wrap(err, "failed to validate payload") return errs.Wrap(http.StatusUnauthorized, err, "validatePayload: failed to validate oidc token payload")
} }
// Validate azp if present // Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errors.New("failed to validate payload: invalid azp") return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
} }
// Enforce an email claim // Enforce an email claim
if p.Email == "" { if p.Email == "" {
return errors.New("failed to validate payload: email not found") return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
} }
// Validate domains (case-insensitive) // Validate domains (case-insensitive)
@ -213,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
} }
if !found { if !found {
return errors.New("failed to validate payload: email is not allowed") return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
} }
} }
@ -229,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
} }
if !found { if !found {
return errors.New("validation failed: invalid group") return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
} }
} }
@ -241,13 +242,15 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrap(http.StatusUnauthorized, err,
"oidc.AuthorizeToken; error parsing oidc token")
} }
// Parse claims to get the kid // Parse claims to get the kid
var claims openIDPayload var claims openIDPayload
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims") return nil, errs.Wrap(http.StatusUnauthorized, err,
"oidc.AuthorizeToken; error parsing oidc token claims")
} }
found := false found := false
@ -260,11 +263,11 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
} }
} }
if !found { if !found {
return nil, errors.New("cannot validate token") return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
} }
if err := o.ValidatePayload(claims); err != nil { if err := o.ValidatePayload(claims); err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeToken")
} }
return &claims, nil return &claims, nil
@ -276,21 +279,21 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error { func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return err return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke")
} }
// Only admins can revoke certificates. // Only admins can revoke certificates.
if o.IsAdmin(claims.Email) { if o.IsAdmin(claims.Email) {
return nil return nil
} }
return errors.New("cannot revoke with non-admin token") return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign")
} }
so := []SignOption{ so := []SignOption{
@ -315,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() { if o.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", o.GetID()) return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
} }
return nil return nil
} }
@ -323,22 +326,22 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() { if !o.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID()) return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
} }
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }
signOptions := []SignOption{ signOptions := []SignOption{
// set the key id to the token email // set the key id to the token email
sshCertificateKeyIDModifier(claims.Email), sshCertKeyIDModifier(claims.Email),
} }
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. // externally.
iden, err := o.getIdentityFunc(o, claims.Email) iden, err := o.getIdentityFunc(o, claims.Email)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "authorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }
defaults := SSHOptions{ defaults := SSHOptions{
CertType: SSHUserCert, CertType: SSHUserCert,
@ -349,12 +352,12 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Non-admin users can only use principals returned by the identityFunc, and // Non-admin users can only use principals returned by the identityFunc, and
// can only sign user certificates. // can only sign user certificates.
if !o.IsAdmin(claims.Email) { if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) signOptions = append(signOptions, sshCertOptionsValidator(defaults))
} }
// Default to a user certificate with usernames as principals if those options // Default to a user certificate with usernames as principals if those options
// are not set. // are not set.
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
return append(signOptions, return append(signOptions,
// Set the default extensions // Set the default extensions
@ -364,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{o.claimer}, &sshCertValidityValidator{o.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }
@ -374,14 +377,14 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return err return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke")
} }
// Only admins can revoke certificates. // Only admins can revoke certificates.
if o.IsAdmin(claims.Email) { if !o.IsAdmin(claims.Email) {
return nil return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
} }
return errors.New("cannot revoke with non-admin token") return nil
} }
func getAndDecode(uri string, v interface{}) error { func getAndDecode(uri string, v interface{}) error {

View file

@ -7,12 +7,14 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -206,20 +208,21 @@ func TestOIDC_authorizeToken(t *testing.T) {
name string name string
prov *OIDC prov *OIDC
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok1", p1, args{t1}, false}, {"ok1", p1, args{t1}, http.StatusOK, false},
{"ok2", p2, args{t2}, false}, {"ok2", p2, args{t2}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, true}, {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
{"fail-domain", p3, args{failDomain}, true}, {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, true},
{"fail-key", p1, args{failKey}, true}, {"fail-key", p1, args{failKey}, http.StatusUnauthorized, true},
{"fail-token", p1, args{failTok}, true}, {"fail-token", p1, args{failTok}, http.StatusUnauthorized, true},
{"fail-claims", p1, args{failClaims}, true}, {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, true},
{"fail-issuer", p1, args{failIss}, true}, {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, true},
{"fail-audience", p1, args{failAud}, true}, {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, true},
{"fail-signature", p1, args{failSig}, true}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, true},
{"fail-expired", p1, args{failExp}, true}, {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, true},
{"fail-not-before", p1, args{failNbf}, true}, {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -230,6 +233,9 @@ func TestOIDC_authorizeToken(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else {
assert.NotNil(t, got) assert.NotNil(t, got)
@ -282,21 +288,24 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
name string name string
prov *OIDC prov *OIDC
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok1", p1, args{t1}, false}, {"ok1", p1, args{t1}, http.StatusOK, false},
{"admin", p3, args{okAdmin}, false}, {"admin", p3, args{okAdmin}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, true}, {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) got, err := tt.prov.AuthorizeSign(context.Background(), tt.args.token)
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else {
if assert.NotNil(t, got) { if assert.NotNil(t, got) {
@ -330,6 +339,107 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
} }
} }
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok1", p1, args{t1}, http.StatusUnauthorized, true},
{"admin", p3, args{okAdmin}, http.StatusOK, false},
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}
}
func TestOIDC_AuthorizeRenew(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
}
}
func TestOIDC_AuthorizeSSHSign(t *testing.T) { func TestOIDC_AuthorizeSSHSign(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
@ -351,9 +461,16 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
p5, err := generateOIDC() p5, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
p6, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains // Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"} p3.Domains = []string{"smallstep.com"}
// disable sshCA
disable := false
p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
// Update configuration endpoints and initialize // Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims} config := Config{Claims: globalProvisionerClaims}
@ -425,48 +542,53 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
prov *OIDC prov *OIDC
args args args args
expected *SSHOptions expected *SSHOptions
code int
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false}, {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"}, &SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub}, {"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"mariano"}, &SSHOptions{CertType: "user", Principals: []string{"mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub}, {"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub},
&SSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, &SSHOptions{CertType: "user", Principals: []string{"max", "mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"}, &SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false}, {"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, http.StatusOK, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false}, {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, http.StatusOK, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"root"}, &SSHOptions{CertType: "user", Principals: []string{"root"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
&SSHOptions{CertType: "user", Principals: []string{"name"}, &SSHOptions{CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub},
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true}, expectedHostOptions, http.StatusOK, false, false},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true}, {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false}, {"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, http.StatusOK, false, true},
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false}, {"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false},
{"fail-sshCA-disabled", p6, args{"foo", SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignSSHMethod) got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
@ -484,36 +606,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
} }
} }
func TestOIDC_AuthorizeRevoke(t *testing.T) { func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p2.Admins = []string{"root@example.com"}
srv := generateJWKServer(2) srv := generateJWKServer(2)
defer srv.Close() defer srv.Close()
var keys jose.JSONWebKeySet var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin + Domains
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims} config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p3.Init(config)) assert.FatalError(t, p2.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) // Invalid email
failEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
// Admin email not in domains // Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) noAdmin, err := generateToken("subject", "the-issuer", p1.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
// Invalid email // Admin email in domains
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) okAdmin, err := generateToken("subject", "the-issuer", p2.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -523,52 +641,22 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
name string name string
prov *OIDC prov *OIDC
args args args args
code int
wantErr bool wantErr bool
}{ }{
{"ok1", p1, args{t1}, true}, {"ok", p2, args{okAdmin}, http.StatusOK, false},
{"admin", p3, args{okAdmin}, false}, {"fail/invalid-token", p1, args{failEmail}, http.StatusUnauthorized, true},
{"fail-email", p3, args{failEmail}, true}, {"fail/not-admin", p1, args{noAdmin}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token) err := tt.prov.AuthorizeSSHRevoke(context.Background(), tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
fmt.Println(tt) t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil {
return sc, ok := err.(errs.StatusCoder)
} assert.Fatal(t, ok, "error does not implement StatusCoder interface")
}) assert.Equals(t, sc.StatusCode(), tt.code)
}
}
func TestOIDC_AuthorizeRenew(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -283,43 +284,43 @@ type base struct{}
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite // AuthorizeSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing x509 Certificates. // this method if they will support authorizing tokens for signing x509 Certificates.
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSign") return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
} }
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking x509 Certificates. // this method if they will support authorizing tokens for revoking x509 Certificates.
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error { func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
return errors.New("not implemented; provisioner does not implement AuthorizeRevoke") return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
} }
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite // AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing x509 Certificates. // this method if they will support authorizing tokens for renewing x509 Certificates.
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return errors.New("not implemented; provisioner does not implement AuthorizeRenew") return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
} }
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing SSH Certificates. // this method if they will support authorizing tokens for signing SSH Certificates.
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHSign") return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
} }
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking SSH Certificates. // this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errors.New("not implemented; provisioner does not implement AuthorizeSSHRevoke") return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
} }
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates. // this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRenew") return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
} }
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite // AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for rekeying SSH Certificates. // this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey") return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }
// Identity is the type representing an externally supplied identity that is used // Identity is the type representing an externally supplied identity that is used

View file

@ -1,10 +1,14 @@
package provisioner package provisioner
import ( import (
"context"
"net/http"
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
) )
func TestType_String(t *testing.T) { func TestType_String(t *testing.T) {
@ -101,3 +105,93 @@ func TestDefaultIdentityFunc(t *testing.T) {
}) })
} }
} }
func TestUnimplementedMethods(t *testing.T) {
tests := []struct {
name string
p Interface
method Method
}{
{"jwk/sshRekey", &JWK{}, SSHRekeyMethod},
{"jwk/sshRenew", &JWK{}, SSHRenewMethod},
{"aws/revoke", &AWS{}, RevokeMethod},
{"aws/sshRenew", &AWS{}, SSHRenewMethod},
{"aws/rekey", &AWS{}, SSHRekeyMethod},
{"aws/sshRevoke", &AWS{}, SSHRevokeMethod},
{"azure/revoke", &Azure{}, RevokeMethod},
{"azure/sshRenew", &Azure{}, SSHRenewMethod},
{"azure/sshRekey", &Azure{}, SSHRekeyMethod},
{"azure/sshRevoke", &Azure{}, SSHRevokeMethod},
{"gcp/revoke", &GCP{}, RevokeMethod},
{"gcp/sshRenew", &GCP{}, SSHRenewMethod},
{"gcp/sshRekey", &GCP{}, SSHRekeyMethod},
{"gcp/sshRevoke", &GCP{}, SSHRevokeMethod},
{"oidc/sshRenew", &OIDC{}, SSHRenewMethod},
{"oidc/sshRekey", &OIDC{}, SSHRekeyMethod},
{"x5c/sshRenew", &X5C{}, SSHRenewMethod},
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
{"acme/revoke", &ACME{}, RevokeMethod},
{"acme/sshSign", &ACME{}, SSHSignMethod},
{"acme/sshRekey", &ACME{}, SSHRekeyMethod},
{"acme/sshRenew", &ACME{}, SSHRenewMethod},
{"acme/sshRevoke", &ACME{}, SSHRevokeMethod},
{"sshpop/sign", &SSHPOP{}, SignMethod},
{"sshpop/renew", &SSHPOP{}, RenewMethod},
{"sshpop/revoke", &SSHPOP{}, RevokeMethod},
{"sshpop/sshSign", &SSHPOP{}, SSHSignMethod},
{"k8ssa/sshRekey", &K8sSA{}, SSHRekeyMethod},
{"k8ssa/sshRenew", &K8sSA{}, SSHRenewMethod},
{"k8ssa/sshRevoke", &K8sSA{}, SSHRevokeMethod},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var (
err error
msg string
)
switch tt.method {
case SignMethod:
var signOpts []SignOption
signOpts, err = tt.p.AuthorizeSign(context.Background(), "")
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSign not implemented"
case RenewMethod:
err = tt.p.AuthorizeRenew(context.Background(), nil)
msg = "provisioner.AuthorizeRenew not implemented"
case RevokeMethod:
err = tt.p.AuthorizeRevoke(context.Background(), "")
msg = "provisioner.AuthorizeRevoke not implemented"
case SSHSignMethod:
var signOpts []SignOption
signOpts, err = tt.p.AuthorizeSSHSign(context.Background(), "")
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSSHSign not implemented"
case SSHRenewMethod:
var cert *ssh.Certificate
cert, err = tt.p.AuthorizeSSHRenew(context.Background(), "")
assert.Nil(t, cert)
msg = "provisioner.AuthorizeSSHRenew not implemented"
case SSHRekeyMethod:
var (
cert *ssh.Certificate
signOpts []SignOption
)
cert, signOpts, err = tt.p.AuthorizeSSHRekey(context.Background(), "")
assert.Nil(t, cert)
assert.Nil(t, signOpts)
msg = "provisioner.AuthorizeSSHRekey not implemented"
case SSHRevokeMethod:
err = tt.p.AuthorizeSSHRevoke(context.Background(), "")
msg = "provisioner.AuthorizeSSHRevoke not implemented"
default:
t.Errorf("unexpected method %s", tt.method)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equals(t, err.Error(), msg)
})
}
}

View file

@ -30,7 +30,7 @@ type SignOption interface{}
// CertificateValidator is the interface used to validate a X.509 certificate. // CertificateValidator is the interface used to validate a X.509 certificate.
type CertificateValidator interface { type CertificateValidator interface {
SignOption SignOption
Valid(crt *x509.Certificate) error Valid(cert *x509.Certificate, o Options) error
} }
// CertificateRequestValidator is the interface used to validate a X.509 // CertificateRequestValidator is the interface used to validate a X.509
@ -106,7 +106,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
return errors.New("certificate request cannot contain an empty common name") return errors.New("certificate request cannot contain an empty common name")
} }
if req.Subject.CommonName != string(v) { if req.Subject.CommonName != string(v) {
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v)
} }
return nil return nil
} }
@ -265,35 +265,32 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
// Valid validates the certificate validity settings (notBefore/notAfter) and // Valid validates the certificate validity settings (notBefore/notAfter) and
// and total duration. // and total duration.
func (v *validityValidator) Valid(crt *x509.Certificate) error { func (v *validityValidator) Valid(cert *x509.Certificate, o Options) error {
var ( var (
na = crt.NotAfter.Truncate(time.Second) na = cert.NotAfter.Truncate(time.Second)
nb = crt.NotBefore.Truncate(time.Second) nb = cert.NotBefore.Truncate(time.Second)
now = time.Now().Truncate(time.Second) now = time.Now().Truncate(time.Second)
) )
// To not take into account the backdate, time.Now() will be used to d := na.Sub(nb)
// calculate the duration if NotBefore is in the past.
var d time.Duration
if now.After(nb) {
d = na.Sub(now)
} else {
d = na.Sub(nb)
}
if na.Before(now) { if na.Before(now) {
return errors.Errorf("NotAfter: %v cannot be in the past", na) return errors.Errorf("notAfter cannot be in the past; na=%v", na)
} }
if na.Before(nb) { if na.Before(nb) {
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb) return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
d, v.min) d, v.min)
} }
if d > v.max { // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough.
if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
d, v.max) d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -3,9 +3,10 @@ package provisioner
import ( import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"fmt"
"net" "net"
"net/url" "net/url"
"reflect" "strings"
"testing" "testing"
"time" "time"
@ -48,22 +49,22 @@ func Test_emailOnlyIdentity_Valid(t *testing.T) {
} }
func Test_defaultPublicKeyValidator_Valid(t *testing.T) { func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
_shortRSA, err := pemutil.Read("./testdata/short-rsa.csr") _shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
assert.FatalError(t, err) assert.FatalError(t, err)
shortRSA, ok := _shortRSA.(*x509.CertificateRequest) shortRSA, ok := _shortRSA.(*x509.CertificateRequest)
assert.Fatal(t, ok) assert.Fatal(t, ok)
_rsa, err := pemutil.Read("./testdata/rsa.csr") _rsa, err := pemutil.Read("./testdata/certs/rsa.csr")
assert.FatalError(t, err) assert.FatalError(t, err)
rsaCSR, ok := _rsa.(*x509.CertificateRequest) rsaCSR, ok := _rsa.(*x509.CertificateRequest)
assert.Fatal(t, ok) assert.Fatal(t, ok)
_ecdsa, err := pemutil.Read("./testdata/ecdsa.csr") _ecdsa, err := pemutil.Read("./testdata/certs/ecdsa.csr")
assert.FatalError(t, err) assert.FatalError(t, err)
ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest) ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest)
assert.Fatal(t, ok) assert.Fatal(t, ok)
_ed25519, err := pemutil.Read("./testdata/ed25519.csr") _ed25519, err := pemutil.Read("./testdata/certs/ed25519.csr")
assert.FatalError(t, err) assert.FatalError(t, err)
ed25519CSR, ok := _ed25519.(*x509.CertificateRequest) ed25519CSR, ok := _ed25519.(*x509.CertificateRequest)
assert.Fatal(t, ok) assert.Fatal(t, ok)
@ -246,30 +247,191 @@ func Test_ipAddressesValidator_Valid(t *testing.T) {
} }
func Test_validityValidator_Valid(t *testing.T) { func Test_validityValidator_Valid(t *testing.T) {
type fields struct { type test struct {
min time.Duration cert *x509.Certificate
max time.Duration opts Options
vv *validityValidator
err error
} }
type args struct { tests := map[string]func() test{
crt *x509.Certificate "fail/notAfter-past": func() test {
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotAfter: time.Now().Add(-5 * time.Minute)},
opts: Options{},
err: errors.New("notAfter cannot be in the past"),
} }
tests := []struct { },
name string "fail/notBefore-after-notAfter": func() test {
fields fields return test{
args args vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
wantErr bool cert: &x509.Certificate{NotBefore: time.Now().Add(10 * time.Minute),
}{ NotAfter: time.Now().Add(5 * time.Minute)},
// TODO: Add test cases. opts: Options{},
err: errors.New("notAfter cannot be before notBefore"),
} }
for _, tt := range tests { },
t.Run(tt.name, func(t *testing.T) { "fail/duration-too-short": func() test {
v := &validityValidator{ n := now()
min: tt.fields.min, return test{
max: tt.fields.max, vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(3 * time.Minute)},
opts: Options{},
err: errors.New("is less than the authorized minimum certificate duration of "),
} }
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr { },
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) "ok/duration-exactly-min": func() test {
n := now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(5 * time.Minute)},
opts: Options{},
} }
},
"fail/duration-too-great": func() test {
n := now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(24*time.Hour + time.Second)},
err: errors.New("is more than the authorized maximum certificate duration of "),
}
},
"ok/duration-exactly-max": func() test {
n := time.Now()
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: &x509.Certificate{NotBefore: n,
NotAfter: n.Add(24 * time.Hour)},
}
},
"ok/duration-exact-min-with-backdate": func() test {
now := time.Now()
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(5 * time.Minute)}
time.Sleep(time.Second)
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: cert,
opts: Options{Backdate: time.Second},
}
},
"ok/duration-exact-max-with-backdate": func() test {
backdate := time.Second
now := time.Now()
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(24*time.Hour + backdate)}
time.Sleep(backdate)
return test{
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
cert: cert,
opts: Options{Backdate: backdate},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
if err := tt.vv.Valid(tt.cert, tt.opts); err != nil {
if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) {
assert.True(t, strings.Contains(err.Error(), tt.err.Error()),
fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error()))
}
} else {
assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not <nil>", tt.err))
}
})
}
}
func Test_profileDefaultDuration_Option(t *testing.T) {
type test struct {
so Options
pdd profileDefaultDuration
cert *x509.Certificate
valid func(*x509.Certificate)
}
tests := map[string]func() test{
"ok/notBefore-notAfter-duration-empty": func() test {
return test{
pdd: profileDefaultDuration(0),
so: Options{},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter))
assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter))
},
}
},
"ok/notBefore-set": func() test {
nb := time.Now().Add(5 * time.Minute).UTC()
return test{
pdd: profileDefaultDuration(0),
so: Options{NotBefore: NewTimeDuration(nb)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
assert.Equals(t, cert.NotBefore, nb)
assert.Equals(t, cert.NotAfter, nb.Add(24*time.Hour))
},
}
},
"ok/duration-set": func() test {
d := 4 * time.Hour
return test{
pdd: profileDefaultDuration(d),
so: Options{Backdate: time.Second},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.True(t, n.Add(d).After(cert.NotAfter))
assert.True(t, n.Add(d).Add(-1*time.Minute).Before(cert.NotAfter))
},
}
},
"ok/notAfter-set": func() test {
na := now().Add(10 * time.Minute).UTC()
return test{
pdd: profileDefaultDuration(0),
so: Options{NotAfter: NewTimeDuration(na)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
n := now()
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.Equals(t, cert.NotAfter, na)
},
}
},
"ok/notBefore-and-notAfter-set": func() test {
nb := time.Now().Add(5 * time.Minute).UTC()
na := time.Now().Add(10 * time.Minute).UTC()
d := 4 * time.Hour
return test{
pdd: profileDefaultDuration(d),
so: Options{NotBefore: NewTimeDuration(nb), NotAfter: NewTimeDuration(na)},
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
assert.Equals(t, cert.NotBefore, nb)
assert.Equals(t, cert.NotAfter, na)
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
assert.FatalError(t, tt.pdd.Option(tt.so)(prof), "unexpected error")
tt.valid(prof.Subject())
}) })
} }
} }
@ -381,43 +543,3 @@ func Test_profileLimitDuration_Option(t *testing.T) {
}) })
} }
} }
func Test_profileDefaultDuration_Option(t *testing.T) {
tm, fn := mockNow()
defer fn()
v := profileDefaultDuration(24 * time.Hour)
type args struct {
so Options
}
tests := []struct {
name string
v profileDefaultDuration
args args
want *x509.Certificate
}{
{"default", v, args{Options{}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(24 * time.Hour)}},
{"backdate", v, args{Options{Backdate: 1 * time.Minute}}, &x509.Certificate{NotBefore: tm.Add(-1 * time.Minute), NotAfter: tm.Add(24 * time.Hour)}},
{"notBefore", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}}, &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
{"notAfter", v, args{Options{NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(1 * time.Hour)}},
{"notBefore and notAfter", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second)), NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}},
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(1 * time.Hour)}},
{"notBefore and backdate", v, args{Options{Backdate: 1 * time.Minute, NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}},
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert := &x509.Certificate{}
profile := &x509util.Leaf{}
profile.SetSubject(cert)
fn := tt.v.Option(tt.args.so)
if err := fn(profile); err != nil {
t.Errorf("profileDefaultDuration.Option() error = %v", err)
}
if !reflect.DeepEqual(cert, tt.want) {
t.Errorf("profileDefaultDuration.Option() = %v, \nwant %v", cert, tt.want)
}
})
}
}

View file

@ -19,29 +19,29 @@ const (
SSHHostCert = "host" SSHHostCert = "host"
) )
// SSHCertificateModifier is the interface used to change properties in an SSH // SSHCertModifier is the interface used to change properties in an SSH
// certificate. // certificate.
type SSHCertificateModifier interface { type SSHCertModifier interface {
SignOption SignOption
Modify(cert *ssh.Certificate) error Modify(cert *ssh.Certificate) error
} }
// SSHCertificateOptionModifier is the interface used to add custom options used // SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate. // to modify the SSH certificate.
type SSHCertificateOptionModifier interface { type SSHCertOptionModifier interface {
SignOption SignOption
Option(o SSHOptions) SSHCertificateModifier Option(o SSHOptions) SSHCertModifier
} }
// SSHCertificateValidator is the interface used to validate an SSH certificate. // SSHCertValidator is the interface used to validate an SSH certificate.
type SSHCertificateValidator interface { type SSHCertValidator interface {
SignOption SignOption
Valid(cert *ssh.Certificate) error Valid(cert *ssh.Certificate, opts SSHOptions) error
} }
// SSHCertificateOptionsValidator is the interface used to validate the custom // SSHCertOptionsValidator is the interface used to validate the custom
// options used to modify the SSH certificate. // options used to modify the SSH certificate.
type SSHCertificateOptionsValidator interface { type SSHCertOptionsValidator interface {
SignOption SignOption
Valid(got SSHOptions) error Valid(got SSHOptions) error
} }
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
return sshCertTypeUInt32(o.CertType) return sshCertTypeUInt32(o.CertType)
} }
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate. // Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SSHOptions) Modify(cert *ssh.Certificate) error { func (o SSHOptions) Modify(cert *ssh.Certificate) error {
switch o.CertType { switch o.CertType {
case "": // ignore case "": // ignore
@ -78,7 +78,7 @@ func (o SSHOptions) Modify(cert *ssh.Certificate) error {
case SSHHostCert: case SSHHostCert:
cert.CertType = ssh.HostCert cert.CertType = ssh.HostCert
default: default:
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType) return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType)
} }
cert.KeyId = o.KeyID cert.KeyId = o.KeyID
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
return nil return nil
} }
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the // sshCertPrincipalsModifier is an SSHCertModifier that sets the
// principals to the SSH certificate. // principals to the SSH certificate.
type sshCertPrincipalsModifier []string type sshCertPrincipalsModifier []string
@ -126,16 +126,16 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given // sshCertKeyIDModifier is an SSHCertModifier that sets the given
// Key ID in the SSH certificate. // Key ID in the SSH certificate.
type sshCertificateKeyIDModifier string type sshCertKeyIDModifier string
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error { func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
cert.KeyId = string(m) cert.KeyId = string(m)
return nil return nil
} }
// sshCertTypeModifier is an SSHCertificateModifier that sets the // sshCertTypeModifier is an SSHCertModifier that sets the
// certificate type. // certificate type.
type sshCertTypeModifier string type sshCertTypeModifier string
@ -145,30 +145,30 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the // sshCertValidAfterModifier is an SSHCertModifier that sets the
// ValidAfter in the SSH certificate. // ValidAfter in the SSH certificate.
type sshCertificateValidAfterModifier uint64 type sshCertValidAfterModifier uint64
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error { func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
cert.ValidAfter = uint64(m) cert.ValidAfter = uint64(m)
return nil return nil
} }
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the // sshCertValidBeforeModifier is an SSHCertModifier that sets the
// ValidBefore in the SSH certificate. // ValidBefore in the SSH certificate.
type sshCertificateValidBeforeModifier uint64 type sshCertValidBeforeModifier uint64
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error { func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
cert.ValidBefore = uint64(m) cert.ValidBefore = uint64(m)
return nil return nil
} }
// sshCertificateDefaultModifier implements a SSHCertificateModifier that // sshCertDefaultsModifier implements a SSHCertModifier that
// modifies the certificate with the given options if they are not set. // modifies the certificate with the given options if they are not set.
type sshCertificateDefaultsModifier SSHOptions type sshCertDefaultsModifier SSHOptions
// Modify implements the SSHCertificateModifier interface. // Modify implements the SSHCertModifier interface.
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error { func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
if cert.CertType == 0 { if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType) cert.CertType = sshCertTypeUInt32(m.CertType)
} }
@ -184,7 +184,7 @@ func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
return nil return nil
} }
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets // sshDefaultExtensionModifier implements an SSHCertModifier that sets
// the default extensions in an SSH certificate. // the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{} type sshDefaultExtensionModifier struct{}
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
} }
} }
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate // sshDefaultDuration is an SSHCertModifier that sets the certificate
// ValidAfter and ValidBefore if they have not been set. It will fail if a // ValidAfter and ValidBefore if they have not been set. It will fail if a
// CertType has not been set or is not valid. // CertType has not been set or is not valid.
type sshDefaultDuration struct { type sshDefaultDuration struct {
*Claimer *Claimer
} }
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier { func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
return sshModifierFunc(func(cert *ssh.Certificate) error { return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType) d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil { if err != nil {
@ -248,7 +248,7 @@ type sshLimitDuration struct {
NotAfter time.Time NotAfter time.Time
} }
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier { func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
if m.NotAfter.IsZero() { if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer} defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o) return defaultDuration.Option(o)
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
}) })
} }
// sshCertificateOptionsValidator validates the user SSHOptions with the ones // sshCertOptionsValidator validates the user SSHOptions with the ones
// usually present in the token. // usually present in the token.
type sshCertificateOptionsValidator SSHOptions type sshCertOptionsValidator SSHOptions
// Valid implements SSHCertificateOptionsValidator and returns nil if both // Valid implements SSHCertOptionsValidator and returns nil if both
// SSHOptions match. // SSHOptions match.
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error { func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
want := SSHOptions(v) want := SSHOptions(v)
return want.match(got) return want.match(got)
} }
type sshCertificateValidityValidator struct { type sshCertValidityValidator struct {
*Claimer *Claimer
} }
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errors.New("ssh certificate validAfter cannot be 0")
@ -336,31 +336,26 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
// calculate the duration if ValidAfter is in the past. // calculate the duration if ValidAfter is in the past.
var dur time.Duration dur := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
if t := now().Unix(); t > int64(cert.ValidAfter) {
dur = time.Duration(int64(cert.ValidBefore)-t) * time.Second
} else {
dur = time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
}
switch { switch {
case dur < min: case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+ return errors.Errorf("requested duration of %s is less than minimum "+
"accepted duration for selected provisioner of %s", dur, min) "accepted duration for selected provisioner of %s", dur, min)
case dur > max: case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+ return errors.Errorf("requested duration of %s is greater than maximum "+
"accepted duration for selected provisioner of %s", dur, max) "accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }
} }
// sshCertificateDefaultValidator implements a simple validator for all the // sshCertDefaultValidator implements a simple validator for all the
// fields in the SSH certificate. // fields in the SSH certificate.
type sshCertificateDefaultValidator struct{} type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields. // Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty") return errors.New("ssh certificate nonce cannot be empty")
@ -395,7 +390,7 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
type sshDefaultPublicKeyValidator struct{} type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured. // Valid checks that certificate request common name matches the one configured.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error { func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if cert.Key == nil { if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil") return errors.New("ssh certificate key cannot be nil")
} }
@ -425,7 +420,7 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
type sshCertKeyIDValidator string type sshCertKeyIDValidator string
// Valid returns an error if the given certificate does not contain the necessary fields. // Valid returns an error if the given certificate does not contain the necessary fields.
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate) error { func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if string(v) != cert.KeyId { if string(v) != cert.KeyId {
return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId) return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId)
} }

View file

@ -38,12 +38,463 @@ func TestSSHOptions_Type(t *testing.T) {
} }
} }
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) { func TestSSHOptions_Modify(t *testing.T) {
type test struct {
so *SSHOptions
cert *ssh.Certificate
valid func(*ssh.Certificate)
err error
}
tests := map[string](func() test){
"fail/unexpected-cert-type": func() test {
return test{
so: &SSHOptions{CertType: "foo"},
cert: new(ssh.Certificate),
err: errors.Errorf("ssh certificate has an unknown type - foo"),
}
},
"fail/validAfter-greater-validBefore": func() test {
return test{
so: &SSHOptions{CertType: "user"},
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
}
},
"ok/user-cert": func() test {
return test{
so: &SSHOptions{CertType: "user"},
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
},
}
},
"ok/host-cert": func() test {
return test{
so: &SSHOptions{CertType: "host"},
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
},
}
},
"ok": func() test {
va := time.Now().Add(5 * time.Minute)
vb := time.Now().Add(1 * time.Hour)
so := &SSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)}
return test{
so: so,
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
assert.Equals(t, cert.KeyId, so.KeyID)
assert.Equals(t, cert.ValidPrincipals, so.Principals)
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.so.Modify(tc.cert); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.valid(tc.cert)
}
}
})
}
}
func TestSSHOptions_Match(t *testing.T) {
type test struct {
so SSHOptions
cmp SSHOptions
err error
}
tests := map[string](func() test){
"fail/cert-type": func() test {
return test{
so: SSHOptions{CertType: "foo"},
cmp: SSHOptions{CertType: "bar"},
err: errors.Errorf("ssh certificate type does not match - got bar, want foo"),
}
},
"fail/pricipals": func() test {
return test{
so: SSHOptions{Principals: []string{"foo"}},
cmp: SSHOptions{Principals: []string{"bar"}},
err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"),
}
},
"fail/validAfter": func() test {
return test{
so: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid after does not match"),
}
},
"fail/validBefore": func() test {
return test{
so: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid before does not match"),
}
},
"ok/original-empty": func() test {
return test{
so: SSHOptions{},
cmp: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
},
}
},
"ok/cmp-empty": func() test {
return test{
cmp: SSHOptions{},
so: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
},
}
},
"ok/equal": func() test {
n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute))
vb := NewTimeDuration(n.Add(5 * time.Minute))
return test{
cmp: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: va,
ValidBefore: vb,
},
so: SSHOptions{
CertType: "foo",
Principals: []string{"foo"},
ValidAfter: va,
ValidBefore: vb,
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.so.match(tc.cmp); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertPrincipalsModifier
cert *ssh.Certificate
expected []string
}
tests := map[string](func() test){
"ok": func() test {
a := []string{"foo", "bar"}
return test{
modifier: sshCertPrincipalsModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
}
})
}
}
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertKeyIDModifier
cert *ssh.Certificate
expected string
}
tests := map[string](func() test){
"ok": func() test {
a := "foo"
return test{
modifier: sshCertKeyIDModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.KeyId, tc.expected)
}
})
}
}
func Test_sshCertTypeModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertTypeModifier
cert *ssh.Certificate
expected uint32
}
tests := map[string](func() test){
"ok/user": func() test {
return test{
modifier: sshCertTypeModifier("user"),
cert: new(ssh.Certificate),
expected: ssh.UserCert,
}
},
"ok/host": func() test {
return test{
modifier: sshCertTypeModifier("host"),
cert: new(ssh.Certificate),
expected: ssh.HostCert,
}
},
"ok/default": func() test {
return test{
modifier: sshCertTypeModifier("foo"),
cert: new(ssh.Certificate),
expected: 0,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.CertType, uint32(tc.expected))
}
})
}
}
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertValidAfterModifier
cert *ssh.Certificate
expected uint64
}
tests := map[string](func() test){
"ok": func() test {
return test{
modifier: sshCertValidAfterModifier(15),
cert: new(ssh.Certificate),
expected: 15,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
assert.Equals(t, tc.cert.ValidAfter, tc.expected)
}
})
}
}
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertDefaultsModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
}
tests := map[string](func() test){
"ok/changes": func() test {
n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute))
vb := NewTimeDuration(n.Add(5 * time.Minute))
so := SSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: va,
ValidBefore: vb,
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, so.Principals)
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
},
}
},
"ok/no-changes": func() test {
n := time.Now()
so := SSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: &ssh.Certificate{
CertType: uint32(ssh.UserCert),
ValidPrincipals: []string{"zap", "zoop"},
ValidAfter: 15,
ValidBefore: 25,
},
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
assert.Equals(t, cert.ValidAfter, uint64(15))
assert.Equals(t, cert.ValidBefore, uint64(25))
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
tc.valid(tc.cert)
}
})
}
}
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
type test struct {
modifier sshDefaultExtensionModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
err error
}
tests := map[string](func() test){
"fail/unexpected-cert-type": func() test {
cert := &ssh.Certificate{CertType: 3}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
err: errors.New("ssh certificate type has not been set or is invalid"),
}
},
"ok/host": func() test {
cert := &ssh.Certificate{CertType: ssh.HostCert}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
assert.Len(t, 0, cert.Extensions)
},
}
},
"ok/user/extensions-exists": func() test {
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
"foo": "bar",
}}}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
val, ok := cert.Extensions["foo"]
assert.True(t, ok)
assert.Equals(t, val, "bar")
val, ok = cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
"ok/user/no-extensions": func() test {
return test{
modifier: sshDefaultExtensionModifier{},
cert: &ssh.Certificate{CertType: ssh.UserCert},
valid: func(cert *ssh.Certificate) {
_, ok := cert.Extensions["foo"]
assert.False(t, ok)
val, ok := cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.modifier.Modify(tc.cert); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.valid(tc.cert)
}
}
})
}
}
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair() pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
sshPub, err := ssh.NewPublicKey(pub) sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertificateDefaultValidator{} v := sshCertDefaultValidator{}
tests := []struct { tests := []struct {
name string name string
cert *ssh.Certificate cert *ssh.Certificate
@ -208,7 +659,7 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil { if err := v.Valid(tt.cert, SSHOptions{}); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -219,34 +670,39 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
} }
} }
func Test_sshCertificateValidityValidator(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertificateValidityValidator{p.claimer} v := sshCertValidityValidator{p.claimer}
n := now() n := now()
tests := []struct { tests := []struct {
name string name string
cert *ssh.Certificate cert *ssh.Certificate
opts SSHOptions
err error err error
}{ }{
{ {
"fail/validAfter-0", "fail/validAfter-0",
&ssh.Certificate{CertType: ssh.UserCert}, &ssh.Certificate{CertType: ssh.UserCert},
SSHOptions{},
errors.New("ssh certificate validAfter cannot be 0"), errors.New("ssh certificate validAfter cannot be 0"),
}, },
{ {
"fail/validBefore-in-past", "fail/validBefore-in-past",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be in the past"), errors.New("ssh certificate validBefore cannot be in the past"),
}, },
{ {
"fail/validBefore-before-validAfter", "fail/validBefore-before-validAfter",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be before validAfter"), errors.New("ssh certificate validBefore cannot be before validAfter"),
}, },
{ {
"fail/cert-type-not-set", "fail/cert-type-not-set",
&ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())}, &ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate type has not been set"), errors.New("ssh certificate type has not been set"),
}, },
{ {
@ -256,6 +712,7 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()), ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
}, },
SSHOptions{},
errors.New("unknown ssh certificate type 3"), errors.New("unknown ssh certificate type 3"),
}, },
{ {
@ -265,8 +722,19 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(4 * time.Minute).Unix()), ValidBefore: uint64(n.Add(4 * time.Minute).Unix()),
}, },
SSHOptions{Backdate: time.Second},
errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"), errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"),
}, },
{
"ok/duration-exactly-min",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(5 * time.Minute).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
},
{ {
"fail/duration>max", "fail/duration>max",
&ssh.Certificate{ &ssh.Certificate{
@ -274,7 +742,18 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(48 * time.Hour).Unix()), ValidBefore: uint64(n.Add(48 * time.Hour).Unix()),
}, },
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m0s"), SSHOptions{Backdate: time.Second},
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"),
},
{
"ok/duration-exactly-max",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
}, },
{ {
"ok", "ok",
@ -283,12 +762,13 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()), ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(8 * time.Hour).Unix()), ValidBefore: uint64(now().Add(8 * time.Hour).Unix()),
}, },
SSHOptions{Backdate: time.Second},
nil, nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil { if err := v.Valid(tt.cert, tt.opts); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -505,7 +985,7 @@ func Test_sshDefaultDuration_Option(t *testing.T) {
{"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}}, {"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}},
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false},
{"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}}, {"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}},
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Minute), ValidBefore: unix(17 * time.Hour)}, false}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false},
{"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}}, {"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}},
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false},
{"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}}, {"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}},
@ -541,7 +1021,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
name string name string
fields fields fields fields
args args args args
want SSHCertificateModifier want SSHCertModifier
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
} }

View file

@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
return nil, err return nil, err
} }
var mods []SSHCertificateModifier var mods []SSHCertModifier
var validators []SSHCertificateValidator var validators []SSHCertValidator
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// modify the ssh.Certificate // modify the ssh.Certificate
case SSHCertificateModifier: case SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions // modify the ssh.Certificate given the SSHOptions
case SSHCertificateOptionModifier: case SSHCertOptionModifier:
mods = append(mods, o.Option(opts)) mods = append(mods, o.Option(opts))
// validate the ssh.Certificate // validate the ssh.Certificate
case SSHCertificateValidator: case SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
// validate the given SSHOptions // validate the given SSHOptions
case SSHCertificateOptionsValidator: case SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, err return nil, err
} }
@ -116,7 +116,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
// User provisioners validators // User provisioners validators
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, err return nil, err
} }
} }

View file

@ -3,11 +3,13 @@ package provisioner
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"net/http"
"strconv" "strconv"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -99,33 +101,31 @@ func (p *SSHPOP) Init(config Config) error {
// claims for case specific downstream parsing. // claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
sshCert, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "authorizeToken ssh-pop") return nil, errs.Wrap(http.StatusUnauthorized, err,
"sshpop.authorizeToken; error extracting sshpop header from token")
} }
// Check for revocation. // Check for revocation.
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil { if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
return nil, errors.Wrap(err, "authorizeToken ssh-pop") return nil, errs.Wrap(http.StatusInternalServerError, err,
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
} else if isRevoked { } else if isRevoked {
return nil, errors.New("authorizeToken ssh-pop: ssh certificate has been revoked") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
} }
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
return nil, errors.New("sshpop certificate validAfter is in the future") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
} }
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
return nil, errors.New("sshpop certificate validBefore is in the past") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errors.New("ssh public key could not be cast to ssh CryptoPublicKey") return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
} }
pubKey := sshCryptoPubKey.CryptoPublicKey() pubKey := sshCryptoPubKey.CryptoPublicKey()
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
} }
if !found { if !found {
return nil, errors.New("error: provisioner could could not verify the sshpop header certificate") return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
} }
// Using the ssh certificates key to validate the claims accomplishes two // Using the ssh certificates key to validate the claims accomplishes two
@ -156,7 +156,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// 2. Asserts that the claims are valid - have not been tampered with. // 2. Asserts that the claims are valid - have not been tampered with.
var claims sshPOPPayload var claims sshPOPPayload
if err = jwt.Claims(pubKey, &claims); err != nil { if err = jwt.Claims(pubKey, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error parsing sshpop token claims")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -165,16 +165,17 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
Issuer: p.Name, Issuer: p.Name,
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; invalid sshpop token")
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
"claim (aud): expected %s, but got %s", audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
} }
claims.sshCert = sshCert claims.sshCert = sshCert
@ -186,12 +187,13 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
if err != nil { if err != nil {
return err return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errors.New("token subject must be equivalent to certificate serial number") return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number")
} }
return err return nil
} }
// AuthorizeSSHRenew validates the authorization token and extracts/validates // AuthorizeSSHRenew validates the authorization token and extracts/validates
@ -199,10 +201,10 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew) claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errors.New("sshpop AuthorizeSSHRenew: sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, nil return claims.sshCert, nil
@ -214,51 +216,52 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey) claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errors.New("sshpop AuthorizeSSHRekey: sshpop certificate must be a host ssh certificate") return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, []SignOption{ return claims.sshCert, []SignOption{
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
// in the sshpop header. If the header is missing, an error is returned. // in the sshpop header. If the header is missing, an error is returned.
func ExtractSSHPOPCert(token string) (*ssh.Certificate, error) { func ExtractSSHPOPCert(token string) (*ssh.Certificate, *jose.JSONWebToken, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, nil, errors.Wrapf(err, "extractSSHPOPCert; error parsing token")
} }
encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"] encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"]
if !ok { if !ok {
return nil, errors.New("token missing sshpop header") return nil, nil, errors.New("extractSSHPOPCert; token missing sshpop header")
} }
encodedSSHCertStr, ok := encodedSSHCert.(string) encodedSSHCertStr, ok := encodedSSHCert.(string)
if !ok { if !ok {
return nil, errors.New("error unexpected type for sshpop header") return nil, nil, errors.Errorf("extractSSHPOPCert; error unexpected type for sshpop header: "+
"want 'string', but got '%T'", encodedSSHCert)
} }
sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr) sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error decoding sshpop header") return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error base64 decoding sshpop header")
} }
sshPub, err := ssh.ParsePublicKey(sshCertBytes) sshPub, err := ssh.ParsePublicKey(sshCertBytes)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing ssh public key") return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error parsing ssh public key")
} }
sshCert, ok := sshPub.(*ssh.Certificate) sshCert, ok := sshPub.(*ssh.Certificate)
if !ok { if !ok {
return nil, errors.New("error converting ssh public key to ssh certificate") return nil, nil, errors.New("extractSSHPOPCert; error converting ssh public key to ssh certificate")
} }
return sshCert, nil return sshCert, jwt, nil
} }
func bytesForSigning(cert *ssh.Certificate) []byte { func bytesForSigning(cert *ssh.Certificate) []byte {

View file

@ -0,0 +1,684 @@
package provisioner
import (
"context"
"crypto"
"crypto/rand"
"encoding/base64"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
)
func TestSSHPOP_Getters(t *testing.T) {
p, err := generateSSHPOP()
assert.FatalError(t, err)
id := "sshpop/" + p.Name
if got := p.GetID(); got != id {
t.Errorf("SSHPOP.GetID() = %v, want %v", got, id)
}
if got := p.GetName(); got != p.Name {
t.Errorf("SSHPOP.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeSSHPOP {
t.Errorf("SSHPOP.GetType() = %v, want %v", got, TypeSSHPOP)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("SSHPOP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil {
return nil, nil, err
}
cert.Key, err = ssh.NewPublicKey(jwk.Public().Key)
if err != nil {
return nil, nil, err
}
if err = cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err
}
return cert, jwk, nil
}
func generateSSHPOPToken(p Interface, cert *ssh.Certificate, jwk *jose.JSONWebKey) (string, error) {
return generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
}
func TestSSHPOP_authorizeToken(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/error-revoked-db-check": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, errors.New("force")
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusInternalServerError,
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
}
},
"fail/cert-already-revoked": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return true, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
}
},
"fail/cert-not-yet-valid": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert,
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"),
}
},
"fail/cert-past-validity": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert,
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"),
}
},
"fail/no-signer-found": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"),
}
},
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, otherJWK)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; error parsing sshpop token claims"),
}
},
"fail/invalid-claims-issuer": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; invalid sshpop token"),
}
},
"fail/invalid-audience": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop token has invalid audience claim (aud)"),
}
},
"fail/empty-subject": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.NotNil(t, claims)
}
}
})
}
}
func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRevoke: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/subject-not-equal-serial": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
userSigner, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
assert.FatalError(t, err)
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
cert *ssh.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRenew: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
}
}
})
}
}
func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
assert.FatalError(t, err)
userSigner, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
assert.FatalError(t, err)
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
p *SSHPOP
token string
cert *ssh.Certificate
err error
code int
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("sshpop.AuthorizeSSHRekey: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
}
},
"fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 3, opts)
for _, o := range opts {
switch v := o.(type) {
case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
}
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
}
}
})
}
}
func TestSSHPOP_ExtractSSHPOPCert(t *testing.T) {
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
hostSigner, ok := hostKey.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
assert.FatalError(t, err)
type test struct {
token string
cert *ssh.Certificate
jwk *jose.JSONWebKey
err error
}
tests := map[string]func(*testing.T) test{
"fail/bad-token": func(t *testing.T) test {
return test{
token: "foo",
err: errors.New("extractSSHPOPCert; error parsing token"),
}
},
"fail/sshpop-missing": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("sub", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk)
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; token missing sshpop header"),
}
},
"fail/wrong-sshpop-type": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", 12345)
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error unexpected type for sshpop header: "),
}
},
"fail/base64decode-error": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", "!@#$%^&*")
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error base64 decoding sshpop header: illegal base64"),
}
},
"fail/parsing-sshpop-pubkey": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString([]byte("foo")))
return nil
})
assert.FatalError(t, err)
return test{
token: tok,
err: errors.New("extractSSHPOPCert; error parsing ssh public key"),
}
},
"ok": func(t *testing.T) test {
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err)
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return test{
token: tok,
jwk: jwk,
cert: cert,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if cert, jwt, err := ExtractSSHPOPCert(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
assert.Equals(t, tc.jwk.KeyID, jwt.Headers[0].KeyID)
}
}
})
}
}

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIHzAUYu3h8e1gL5ONGZo+lghJJa9rl1TvP2UlqDXazxvoAoGCCqGSM49
AwEHoUQDQgAEOLScS+1Yzmqdyots9lSC0tzTSXUXEgyOD9wYrQ0BqnVZtBXlQw1p
m3fnF/7Ehl6bD1YZWjrF1t+IBZQMq1uBBw==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINWGD2xneE43YeytQzORItISxv6d/oH+9TXvDKHo6TyXoAoGCCqGSM49
AwEHoUQDQgAEVK/EtXgVV7+7ppnQSjCtI5qb/gIGnQUF4i//F/JKKho7kRNyMDSn
BP3kndiv8Yfxg4PsyIRY5ZofbEo5eJE6bg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
-----END EC PRIVATE KEY-----

View file

@ -19,6 +19,7 @@ import (
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh"
) )
var ( var (
@ -47,24 +48,6 @@ var (
} }
) )
func provisionerClaims() *Claims {
ddr := false
des := true
return &Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &ddr,
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &des,
}
}
const awsTestCertificate = `-----BEGIN CERTIFICATE----- const awsTestCertificate = `-----BEGIN CERTIFICATE-----
MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw
GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0 GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0
@ -204,7 +187,7 @@ func generateJWK() (*JWK, error) {
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
fooPubB, err := ioutil.ReadFile("./testdata/foo.pub") fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -212,7 +195,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
barPubB, err := ioutil.ReadFile("./testdata/bar.pub") barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,6 +223,46 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
}, nil }, nil
} }
func generateSSHPOP() (*SSHPOP, error) {
name, err := randutil.Alphanumeric(10)
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil {
return nil, err
}
userKey, _, _, _, err := ssh.ParseAuthorizedKey(userB)
if err != nil {
return nil, err
}
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
if err != nil {
return nil, err
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(hostB)
if err != nil {
return nil, err
}
return &SSHPOP{
Name: name,
Type: "SSHPOP",
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey},
},
}, nil
}
func generateX5C(root []byte) (*X5C, error) { func generateX5C(root []byte) (*X5C, error) {
if root == nil { if root == nil {
root = []byte(`-----BEGIN CERTIFICATE----- root = []byte(`-----BEGIN CERTIFICATE-----
@ -589,6 +612,13 @@ func withX5CHdr(certs []*x509.Certificate) tokOption {
} }
} }
func withSSHPOPFile(cert *ssh.Certificate) tokOption {
return func(so *jose.SignerOptions) error {
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal()))
return nil
}
}
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")
@ -630,6 +660,24 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateX5CSSHToken(jwk *jose.JSONWebKey, claims *x5cPayload, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)
for _, o := range tokOpts {
if err := o(so); err != nil {
return "", err
}
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func getK8sSAPayload() *k8sSAPayload { func getK8sSAPayload() *k8sSAPayload {
return &k8sSAPayload{ return &k8sSAPayload{
Claims: jose.Claims{ Claims: jose.Claims{

View file

@ -4,9 +4,11 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"net/http"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -121,19 +123,20 @@ func (p *X5C) Init(config Config) error {
func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) { func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing token") return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c token")
} }
verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{
Roots: p.rootPool, Roots: p.rootPool,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error verifying x5c certificate chain") return nil, errs.Wrap(http.StatusUnauthorized, err,
"x5c.authorizeToken; error verifying x5c certificate chain in token")
} }
leaf := verifiedChains[0][0] leaf := verifiedChains[0][0]
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
return nil, errors.New("certificate used to sign x5c token cannot be used for digital signature") return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
} }
// Using the leaf certificates key to validate the claims accomplishes two // Using the leaf certificates key to validate the claims accomplishes two
@ -143,7 +146,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// 2. Asserts that the claims are valid - have not been tampered with. // 2. Asserts that the claims are valid - have not been tampered with.
var claims x5cPayload var claims x5cPayload
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil { if err = jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims") return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c claims")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -152,16 +155,17 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
Issuer: p.Name, Issuer: p.Name,
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, errors.Wrapf(err, "invalid token") return nil, errs.Wrapf(http.StatusUnauthorized, err, "x5c.authorizeToken; invalid x5c claims")
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) { if !matchesAudience(claims.Audience, audiences) {
return nil, errors.New("invalid token: invalid audience claim (aud)") return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", audiences, claims.Audience)
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, errors.New("token subject cannot be empty") return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
} }
// Save the verified chains on the x5c payload object. // Save the verified chains on the x5c payload object.
@ -173,14 +177,14 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.audiences.Revoke)
return err return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
} }
// NOTE: This is for backwards compatibility with older versions of cli // NOTE: This is for backwards compatibility with older versions of cli
@ -209,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
} }
return nil return nil
} }
@ -217,22 +221,22 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.claimer.IsSSHCAEnabled() {
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
} }
if claims.Step == nil || claims.Step.SSH == nil { if claims.Step == nil || claims.Step.SSH == nil {
return nil, errors.New("authorization token must be an SSH provisioning token") return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
} }
opts := claims.Step.SSH opts := claims.Step.SSH
signOptions := []SignOption{ signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token // validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts), sshCertOptionsValidator(*opts),
} }
// Add modifiers from custom claims // Add modifiers from custom claims
@ -245,18 +249,18 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
} }
t := now() t := now()
if !opts.ValidAfter.IsZero() { if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
} }
if !opts.ValidBefore.IsZero() { if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
} }
// Make sure to define the the KeyID // Make sure to define the the KeyID
if opts.KeyID == "" { if opts.KeyID == "" {
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject)) signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
} }
// Default to a user certificate with no principals if not set // Default to a user certificate with no principals if not set
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
return append(signOptions, return append(signOptions,
// Set the default extensions. // Set the default extensions.
@ -268,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertificateValidityValidator{p.claimer}, &sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
} }

View file

@ -2,14 +2,16 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509"
"net" "net"
"net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
@ -151,9 +153,15 @@ M46l92gdOozT
} }
func TestX5C_authorizeToken(t *testing.T) { func TestX5C_authorizeToken(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
type test struct { type test struct {
p *X5C p *X5C
token string token string
code int
err error err error
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
@ -163,7 +171,8 @@ func TestX5C_authorizeToken(t *testing.T) {
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
err: errors.New("error parsing token"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error parsing x5c token"),
} }
}, },
"fail/invalid-cert-chain": func(t *testing.T) test { "fail/invalid-cert-chain": func(t *testing.T) test {
@ -190,7 +199,8 @@ a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo=
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
} }
}, },
"fail/doubled-up-self-signed-cert": func(t *testing.T) test { "fail/doubled-up-self-signed-cert": func(t *testing.T) test {
@ -228,7 +238,8 @@ EXAHTA9L
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
} }
}, },
"fail/digital-signature-ext-required": func(t *testing.T) test { "fail/digital-signature-ext-required": func(t *testing.T) test {
@ -269,7 +280,8 @@ lgsqsR63is+0YQ==
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("certificate used to sign x5c token cannot be used for digital signature"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"),
} }
}, },
"fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test { "fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test {
@ -309,74 +321,58 @@ lgsqsR63is+0YQ==
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; error parsing x5c claims"),
} }
}, },
"fail/invalid-issuer": func(t *testing.T) test { "fail/invalid-issuer": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", tok, err := generateToken("", "foobar", testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(certs)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; invalid x5c claims"),
} }
}, },
"fail/invalid-audience": func(t *testing.T) test { "fail/invalid-audience": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), "foobar", "", tok, err := generateToken("", p.GetName(), "foobar", "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(certs)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("invalid token: invalid audience claim (aud)"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; x5c token has invalid audience claim (aud)"),
} }
}, },
"fail/empty-subject": func(t *testing.T) test { "fail/empty-subject": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(certs)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: tok, token: tok,
err: errors.New("token subject cannot be empty"), code: http.StatusUnauthorized,
err: errors.New("x5c.authorizeToken; x5c token subject cannot be empty"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk, []string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(certs)) withX5CHdr(x5cCerts))
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
@ -389,6 +385,9 @@ lgsqsR63is+0YQ==
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
@ -402,10 +401,15 @@ lgsqsR63is+0YQ==
} }
func TestX5C_AuthorizeSign(t *testing.T) { func TestX5C_AuthorizeSign(t *testing.T) {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
type test struct { type test struct {
p *X5C p *X5C
token string token string
ctx context.Context code int
err error err error
dns []string dns []string
emails []string emails []string
@ -418,56 +422,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
ctx: NewContextWithMethod(context.Background(), SignMethod), code: http.StatusUnauthorized,
err: errors.New("error parsing token"), err: errors.New("x5c.AuthorizeSign: x5c.authorizeToken; error parsing x5c token"),
}
},
"fail/ssh/disabled": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
p.claimer.claims = provisionerClaims()
*p.claimer.claims.EnableSSHCA = false
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.Errorf("ssh ca is disabled for provisioner x5c/%s", p.GetName()),
}
},
"fail/ssh/invalid-token": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
token: tok,
err: errors.New("authorization token must be an SSH provisioning token"),
} }
}, },
"ok/empty-sans": func(t *testing.T) test { "ok/empty-sans": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
@ -476,7 +435,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok, token: tok,
dns: []string{"foo"}, dns: []string{"foo"},
emails: []string{}, emails: []string{},
@ -484,11 +442,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} }
}, },
"ok/multi-sans": func(t *testing.T) test { "ok/multi-sans": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
@ -497,7 +450,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
ctx: NewContextWithMethod(context.Background(), SignMethod),
token: tok, token: tok,
dns: []string{"foo"}, dns: []string{"foo"},
emails: []string{"max@smallstep.com"}, emails: []string{"max@smallstep.com"},
@ -508,8 +460,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
@ -554,126 +509,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} }
} }
func TestX5C_AuthorizeSSHSign(t *testing.T) {
_, fn := mockNow()
defer fn()
type test struct {
p *X5C
token string
claims *x5cPayload
err error
}
tests := map[string]func(*testing.T) test{
"fail/no-Step-claim": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
claims: new(x5cPayload),
err: errors.New("authorization token must be an SSH provisioning token"),
}
},
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{Step: new(stepPayload)},
err: errors.New("authorization token must be an SSH provisioning token"),
}
},
"ok/with-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{
Step: &stepPayload{SSH: &SSHOptions{
CertType: SSHHostCert,
Principals: []string{"max", "mariano", "alan"},
ValidAfter: TimeDuration{d: 5 * time.Minute},
ValidBefore: TimeDuration{d: 10 * time.Minute},
}},
Claims: jose.Claims{Subject: "foo"},
chains: [][]*x509.Certificate{certs},
},
}
},
"ok/without-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
assert.FatalError(t, err)
return test{
p: p,
claims: &x5cPayload{
Step: &stepPayload{SSH: &SSHOptions{}},
Claims: jose.Claims{Subject: "foo"},
chains: [][]*x509.Certificate{certs},
},
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
nw := now()
for _, o := range opts {
switch v := o.(type) {
case sshCertificateOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
case sshCertificateKeyIDModifier:
assert.Equals(t, string(v), "foo")
case sshCertTypeModifier:
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
case sshCertPrincipalsModifier:
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
case sshCertificateValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
case sshCertificateValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
case sshCertificateDefaultsModifier:
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, tc.claims.chains[0][0].NotAfter)
case *sshCertificateValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertificateDefaultValidator:
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
assert.Equals(t, tot, 12)
} else {
assert.Equals(t, tot, 8)
}
}
}
}
})
}
}
func TestX5C_AuthorizeRevoke(t *testing.T) { func TestX5C_AuthorizeRevoke(t *testing.T) {
type test struct { type test struct {
p *X5C p *X5C
token string token string
code int
err error err error
} }
tests := map[string]func(*testing.T) test{ tests := map[string]func(*testing.T) test{
@ -683,13 +523,14 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
err: errors.New("error parsing token"), code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeRevoke: x5c.authorizeToken; error parsing x5c token"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err) assert.FatalError(t, err)
p, err := generateX5C(nil) p, err := generateX5C(nil)
@ -707,8 +548,11 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
@ -719,33 +563,248 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
} }
func TestX5C_AuthorizeRenew(t *testing.T) { func TestX5C_AuthorizeRenew(t *testing.T) {
p1, err := generateX5C(nil) type test struct {
p *X5C
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateX5C(nil)
assert.FatalError(t, err)
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{
type args struct { p: p,
cert *x509.Certificate code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()),
}
},
"ok": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestX5C_AuthorizeSSHSign(t *testing.T) {
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
assert.FatalError(t, err)
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
assert.FatalError(t, err)
_, fn := mockNow()
defer fn()
type test struct {
p *X5C
token string
claims *x5cPayload
code int
err error
}
tests := map[string]func(*testing.T) test{
"fail/sshCA-disabled": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
// disable sshCA
enable := false
p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()),
}
},
"fail/invalid-token": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign: x5c.authorizeToken; error parsing x5c token"),
}
},
"fail/no-Step-claim": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "",
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
}
},
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
}
},
"ok/with-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{SSH: &SSHOptions{
CertType: SSHHostCert,
Principals: []string{"max", "mariano", "alan"},
ValidAfter: TimeDuration{d: 5 * time.Minute},
ValidBefore: TimeDuration{d: 10 * time.Minute},
}},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
claims: claims,
token: tok,
}
},
"ok/without-claims": func(t *testing.T) test {
p, err := generateX5C(nil)
assert.FatalError(t, err)
id, err := randutil.ASCII(64)
assert.FatalError(t, err)
now := time.Now()
claims := &x5cPayload{
Claims: jose.Claims{
ID: id,
Subject: "foo",
Issuer: p.GetName(),
IssuedAt: jose.NewNumericDate(now),
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
Audience: []string{testAudiences.SSHSign[0]},
},
Step: &stepPayload{SSH: &SSHOptions{}},
}
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
assert.FatalError(t, err)
return test{
p: p,
claims: claims,
token: tok,
}
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
nw := now()
for _, o := range opts {
switch v := o.(type) {
case sshCertOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
case sshCertKeyIDModifier:
assert.Equals(t, string(v), "foo")
case sshCertTypeModifier:
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
case sshCertPrincipalsModifier:
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
case sshCertValidAfterModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
case sshCertValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
case sshCertDefaultsModifier:
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertDefaultValidator:
case sshCertKeyIDValidator:
assert.Equals(t, string(v), "foo")
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
assert.Equals(t, tot, 13)
} else {
assert.Equals(t, tot, 9)
}
} }
tests := []struct {
name string
prov *X5C
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }

View file

@ -2,18 +2,16 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
// GetEncryptedKey returns the JWE key corresponding to the given kid argument. // GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) { func (a *Authority) GetEncryptedKey(kid string) (string, error) {
key, ok := a.provisioners.LoadEncryptedKey(kid) key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok { if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), return "", errs.NotFound("encrypted key with kid %s was not found", kid)
http.StatusNotFound, apiCtx{}}
} }
return key, nil return key, nil
} }
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt) p, ok := a.provisioners.LoadByCertificate(crt)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"), return nil, errs.NotFound("provisioner not found")
http.StatusNotFound, apiCtx{}}
} }
return p, nil return p, nil
} }
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
p, ok := a.provisioners.Load(id) p, ok := a.provisioners.Load(id)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"), return nil, errs.NotFound("provisioner not found")
http.StatusNotFound, apiCtx{}}
} }
return p, nil return p, nil
} }

View file

@ -7,13 +7,15 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
type ek struct { type ek struct {
a *Authority a *Authority
kid string kid string
err *apiError err error
code int
} }
tests := map[string]func(t *testing.T) *ek{ tests := map[string]func(t *testing.T) *ek{
"ok": func(t *testing.T) *ek { "ok": func(t *testing.T) *ek {
@ -34,8 +36,8 @@ func TestGetEncryptedKey(t *testing.T) {
return &ek{ return &ek{
a: a, a: a,
kid: "foo", kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"), err: errors.New("encrypted key with kid foo was not found"),
http.StatusNotFound, apiCtx{}}, code: http.StatusNotFound,
} }
}, },
} }
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid) ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch v := err.(type) { sc, ok := err.(errs.StatusCoder)
case *apiError: assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.code, tc.err.code) assert.HasPrefix(t, err.Error(), tc.err.Error())
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
@ -72,7 +70,8 @@ func TestGetEncryptedKey(t *testing.T) {
func TestGetProvisioners(t *testing.T) { func TestGetProvisioners(t *testing.T) {
type gp struct { type gp struct {
a *Authority a *Authority
err *apiError err error
code int
} }
tests := map[string]func(t *testing.T) *gp{ tests := map[string]func(t *testing.T) *gp{
"ok": func(t *testing.T) *gp { "ok": func(t *testing.T) *gp {
@ -91,14 +90,10 @@ func TestGetProvisioners(t *testing.T) {
ps, next, err := tc.a.GetProvisioners("", 0) ps, next, err := tc.a.GetProvisioners("", 0)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch v := err.(type) { sc, ok := err.(errs.StatusCoder)
case *apiError: assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.code, tc.err.code) assert.HasPrefix(t, err.Error(), tc.err.Error())
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {

View file

@ -2,23 +2,20 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"net/http"
"github.com/pkg/errors" "github.com/smallstep/certificates/errs"
) )
// Root returns the certificate corresponding to the given SHA sum argument. // Root returns the certificate corresponding to the given SHA sum argument.
func (a *Authority) Root(sum string) (*x509.Certificate, error) { func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum) val, ok := a.certificates.Load(sum)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum), return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
http.StatusNotFound, apiCtx{}}
} }
crt, ok := val.(*x509.Certificate) crt, ok := val.(*x509.Certificate)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), return nil, errs.InternalServer("stored value is not a *x509.Certificate")
http.StatusInternalServerError, apiCtx{}}
} }
return crt, nil return crt, nil
} }
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
crt, ok := v.(*x509.Certificate) crt, ok := v.(*x509.Certificate)
if !ok { if !ok {
federation = nil federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), err = errs.InternalServer("stored value is not a *x509.Certificate")
http.StatusInternalServerError, apiCtx{}}
return false return false
} }
federation = append(federation, crt) federation = append(federation, crt)

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
) )
@ -17,11 +18,12 @@ func TestRoot(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
sum string sum string
err *apiError err error
code int
}{ }{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}}, "not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}}, "invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
} }
for name, tc := range tests { for name, tc := range tests {
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum) crt, err := a.Root(tc.sum)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch v := err.(type) { sc, ok := err.(errs.StatusCoder)
case *apiError: assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.code, tc.err.code) assert.HasPrefix(t, err.Error(), tc.err.Error())
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {

View file

@ -122,10 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
// GetSSHConfig returns rendered templates for clients (user) or servers (host). // GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotFound("getSSHConfig: ssh is not configured")
err: errors.New("getSSHConfig: ssh is not configured"),
code: http.StatusNotFound,
}
} }
var ts []templates.Template var ts []templates.Template
@ -139,10 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
ts = a.config.Templates.SSH.Host ts = a.config.Templates.SSH.Host
} }
default: default:
return nil, &apiError{ return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
err: errors.Errorf("getSSHConfig: type %s is not valid", typ),
code: http.StatusBadRequest,
}
} }
// Merge user and default data // Merge user and default data
@ -174,7 +168,8 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
// hostname. // hostname.
func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) { func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) {
if a.sshBastionFunc != nil { if a.sshBastionFunc != nil {
return a.sshBastionFunc(user, hostname) bs, err := a.sshBastionFunc(user, hostname)
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
} }
if a.config.SSH != nil { if a.config.SSH != nil {
if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" { if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" {
@ -182,32 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
} }
return nil, nil return nil, nil
} }
return nil, &apiError{ return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
err: errors.New("getSSHBastion: ssh is not configured"),
code: http.StatusNotFound,
}
}
// authorizeSSHSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
// list of methods to apply to the signing flow.
func (a *Authority) authorizeSSHSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
var errContext = apiCtx{"ott": ott}
p, err := a.authorizeToken(ctx, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
}
opts, err := p.AuthorizeSSHSign(ctx, ott)
if err != nil {
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
}
return opts, nil
} }
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertificateModifier var mods []provisioner.SSHCertModifier
var validators []provisioner.SSHCertificateValidator var validators []provisioner.SSHCertValidator
// Set backdate with the configured value // Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
@ -215,38 +191,32 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// modify the ssh.Certificate // modify the ssh.Certificate
case provisioner.SSHCertificateModifier: case provisioner.SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions // modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertificateOptionModifier: case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts)) mods = append(mods, o.Option(opts))
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertificateValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
// validate the given SSHOptions // validate the given SSHOptions
case provisioner.SSHCertificateOptionsValidator: case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden} return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
default: default:
return nil, &apiError{ return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
err: errors.Errorf("signSSH: invalid extra option type %T", o),
code: http.StatusInternalServerError,
}
} }
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError} return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
} }
var serial uint64 var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error reading random number")
err: errors.Wrap(err, "signSSH: error reading random number"),
code: http.StatusInternalServerError,
}
} }
// Build base certificate with the key and some random values // Build base certificate with the key and some random values
@ -258,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Use opts to modify the certificate // Use opts to modify the certificate
if err := opts.Modify(cert); err != nil { if err := opts.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden} return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
// Use provisioner modifiers // Use provisioner modifiers
for _, m := range mods { for _, m := range mods {
if err := m.Modify(cert); err != nil { if err := m.Modify(cert); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden} return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
} }
@ -273,25 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
err: errors.New("signSSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
err: errors.New("signSSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, &apiError{ return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -302,71 +263,38 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Sign the certificate // Sign the certificate
sig, err := signer.Sign(rand.Reader, data) sig, err := signer.Sign(rand.Reader, data)
if err != nil { if err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
err: errors.Wrap(err, "signSSH: error signing certificate"),
code: http.StatusInternalServerError,
}
} }
cert.Signature = sig cert.Signature = sig
// User provisioners validators // User provisioners validators
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden} return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db")
err: errors.Wrap(err, "signSSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
} }
return cert, nil return cert, nil
} }
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
cert, err := p.AuthorizeSSHRenew(ctx, token)
if err != nil {
return nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
return cert, nil
}
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError} return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
} }
var serial uint64 var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error reading random number")
err: errors.Wrap(err, "renewSSH: error reading random number"),
code: http.StatusInternalServerError,
}
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errors.New("rewnewSSH: cannot renew certificate without validity period") return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
} }
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
@ -393,25 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
err: errors.New("renewSSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
err: errors.New("renewSSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, &apiError{ return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
err: errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -422,79 +341,43 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
// Sign the certificate // Sign the certificate
sig, err := signer.Sign(rand.Reader, data) sig, err := signer.Sign(rand.Reader, data)
if err != nil { if err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error signing certificate")
err: errors.Wrap(err, "renewSSH: error signing certificate"),
code: http.StatusInternalServerError,
}
} }
cert.Signature = sig cert.Signature = sig
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
err: errors.Wrap(err, "renewSSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
} }
return cert, nil return cert, nil
} }
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return nil, nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRenew"),
code: http.StatusUnauthorized,
context: errContext,
}
}
cert, opts, err := p.AuthorizeSSHRekey(ctx, token)
if err != nil {
return nil, nil, &apiError{
err: errors.Wrap(err, "authorizeSSHRekey"),
code: http.StatusUnauthorized,
context: errContext,
}
}
return cert, opts, nil
}
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertificateValidator var validators []provisioner.SSHCertValidator
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertificateValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
default: default:
return nil, &apiError{ return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
err: errors.Errorf("rekeySSH: invalid extra option type %T", o),
code: http.StatusInternalServerError,
}
} }
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError} return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
} }
var serial uint64 var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error reading random number")
err: errors.Wrap(err, "rekeySSH: error reading random number"),
code: http.StatusInternalServerError,
}
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errors.New("rekeySSH: cannot rekey certificate without validity period") return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
} }
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
@ -521,25 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
err: errors.New("rekeySSH: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
err: errors.New("rekeySSH: host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, &apiError{ return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
err: errors.Errorf("rekeySSH: unexpected ssh certificate type: %d", cert.CertType),
code: http.StatusInternalServerError,
}
} }
cert.SignatureKey = signer.PublicKey() cert.SignatureKey = signer.PublicKey()
@ -547,80 +421,47 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
data := cert.Marshal() data := cert.Marshal()
data = data[:len(data)-4] data = data[:len(data)-4]
// Sign the certificate // Sign the certificate.
sig, err := signer.Sign(rand.Reader, data) sig, err := signer.Sign(rand.Reader, data)
if err != nil { if err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error signing certificate")
err: errors.Wrap(err, "rekeySSH: error signing certificate"),
code: http.StatusInternalServerError,
}
} }
cert.Signature = sig cert.Signature = sig
// User provisioners validators // Apply validators from provisioner.
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, provisioner.SSHOptions{Backdate: backdate}); err != nil {
return nil, &apiError{err: err, code: http.StatusForbidden} return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
} }
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
err: errors.Wrap(err, "rekeySSH: error storing certificate in db"),
code: http.StatusInternalServerError,
}
} }
return cert, nil return cert, nil
} }
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
// validating the contents of an SSHPOP token.
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
errContext := map[string]interface{}{"ott": token}
p, err := a.authorizeToken(ctx, token)
if err != nil {
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
}
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
}
return nil
}
// SignSSHAddUser signs a certificate that provisions a new user in a server. // SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, &apiError{ return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
} }
if subject.CertType != ssh.UserCert { if subject.CertType != ssh.UserCert {
return nil, &apiError{ return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
err: errors.New("signSSHAddUser: certificate is not a user certificate"),
code: http.StatusForbidden,
}
} }
if len(subject.ValidPrincipals) != 1 { if len(subject.ValidPrincipals) != 1 {
return nil, &apiError{ return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
err: errors.New("signSSHAddUser: certificate does not have only one principal"),
code: http.StatusForbidden,
}
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, &apiError{err: err, code: http.StatusInternalServerError} return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
} }
var serial uint64 var serial uint64
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number")
err: errors.Wrap(err, "signSSHAddUser: error reading random number"),
code: http.StatusInternalServerError,
}
} }
signer := a.sshCAUserCertSignKey signer := a.sshCAUserCertSignKey
@ -656,10 +497,7 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate)
cert.Signature = sig cert.Signature = sig
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
err: errors.Wrap(err, "signSSHAddUser: error storing certificate in db"),
code: http.StatusInternalServerError,
}
} }
return cert, nil return cert, nil
@ -691,14 +529,12 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
// GetSSHHosts returns a list of valid host principals. // GetSSHHosts returns a list of valid host principals.
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
if a.sshGetHostsFunc != nil { if a.sshGetHostsFunc != nil {
return a.sshGetHostsFunc(cert) hosts, err := a.sshGetHostsFunc(cert)
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
} }
hostnames, err := a.db.GetSSHHostPrincipals() hostnames, err := a.db.GetSSHHostPrincipals()
if err != nil { if err != nil {
return nil, &apiError{ return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
err: errors.Wrap(err, "getSSHHosts"),
code: http.StatusInternalServerError,
}
} }
hosts := make([]sshutil.Host, len(hostnames)) hosts := make([]sshutil.Host, len(hostnames))

View file

@ -5,8 +5,10 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -15,6 +17,8 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -58,7 +62,7 @@ func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
type sshTestCertValidator string type sshTestCertValidator string
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error { func (v sshTestCertValidator) Valid(crt *ssh.Certificate, opts provisioner.SSHOptions) error {
if v == "" { if v == "" {
return nil return nil
} }
@ -76,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
type sshTestOptionsModifier string type sshTestOptionsModifier string
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier { func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
return sshTestCertModifier(string(m)) return sshTestCertModifier(string(m))
} }
@ -488,18 +492,18 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
want bool want bool
wantErr bool wantErr bool
}{ }{
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false}, {"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false}, {"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, {"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t) a := testAuthority(t)
a.db = &MockAuthDB{ a.db = &db.MockAuthDB{
isSSHHost: func(_ string) (bool, error) { MIsSSHHost: func(_ string) (bool, error) {
return tt.fields.exists, tt.fields.err return tt.fields.exists, tt.fields.err
}, },
} }
@ -640,6 +644,9 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil {
_, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
@ -647,3 +654,266 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
}) })
} }
} }
func TestAuthority_GetSSHHosts(t *testing.T) {
a := testAuthority(t)
type test struct {
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
auth *Authority
cert *x509.Certificate
cmp func(got []sshutil.Host)
err error
code int
}
tests := map[string]func(t *testing.T) *test{
"fail/getHostsFunc-fail": func(t *testing.T) *test {
return &test{
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
return nil, errors.New("force")
},
cert: &x509.Certificate{},
err: errors.New("getSSHHosts: force"),
code: http.StatusInternalServerError,
}
},
"ok/getHostsFunc-defined": func(t *testing.T) *test {
hosts := []sshutil.Host{
{HostID: "1", Hostname: "foo"},
{HostID: "2", Hostname: "bar"},
}
return &test{
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
return hosts, nil
},
cert: &x509.Certificate{},
cmp: func(got []sshutil.Host) {
assert.Equals(t, got, hosts)
},
}
},
"fail/db-get-fail": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MGetSSHHostPrincipals: func() ([]string, error) {
return nil, errors.New("force")
},
})),
cert: &x509.Certificate{},
err: errors.New("getSSHHosts: force"),
code: http.StatusInternalServerError,
}
},
"ok": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MGetSSHHostPrincipals: func() ([]string, error) {
return []string{"foo", "bar"}, nil
},
})),
cert: &x509.Certificate{},
cmp: func(got []sshutil.Host) {
assert.Equals(t, got, []sshutil.Host{
{Hostname: "foo"},
{Hostname: "bar"},
})
},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
auth := tc.auth
if auth == nil {
auth = a
}
auth.sshGetHostsFunc = tc.getHostsFunc
hosts, err := auth.GetSSHHosts(tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.cmp(hosts)
}
}
})
}
}
func TestAuthority_RekeySSH(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
pub, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
signer, err := ssh.NewSignerFromKey(signKey)
assert.FatalError(t, err)
userOptions := sshTestModifier{
CertType: ssh.UserCert,
}
now := time.Now().UTC()
a := testAuthority(t)
type test struct {
auth *Authority
userSigner ssh.Signer
hostSigner ssh.Signer
cert *ssh.Certificate
key ssh.PublicKey
signOpts []provisioner.SignOption
cmpResult func(old, n *ssh.Certificate)
err error
code int
}
tests := map[string]func(t *testing.T) *test{
"fail/opts-type": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
key: pub,
signOpts: []provisioner.SignOption{userOptions},
err: errors.New("rekeySSH; invalid extra option type"),
code: http.StatusInternalServerError,
}
},
"fail/old-cert-validAfter": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
code: http.StatusBadRequest,
}
},
"fail/old-cert-validBefore": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
code: http.StatusBadRequest,
}
},
"fail/old-cert-no-user-key": func(t *testing.T) *test {
return &test{
userSigner: nil,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; user certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
},
"fail/old-cert-no-host-key": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; host certificate signing is not enabled"),
code: http.StatusNotImplemented,
}
},
"fail/unexpected-old-cert-type": func(t *testing.T) *test {
return &test{
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"),
code: http.StatusBadRequest,
}
},
"fail/db-store": func(t *testing.T) *test {
return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MStoreSSHCertificate: func(cert *ssh.Certificate) error {
return errors.New("force")
},
})),
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; error storing certificate in db: force"),
code: http.StatusInternalServerError,
}
},
"ok": func(t *testing.T) *test {
va1 := now.Add(-24 * time.Hour)
vb1 := now.Add(-23 * time.Hour)
return &test{
userSigner: signer,
hostSigner: nil,
cert: &ssh.Certificate{
ValidAfter: uint64(va1.Unix()),
ValidBefore: uint64(vb1.Unix()),
CertType: ssh.UserCert,
ValidPrincipals: []string{"foo", "bar"},
KeyId: "foo",
},
key: pub,
signOpts: []provisioner.SignOption{},
cmpResult: func(old, n *ssh.Certificate) {
assert.Equals(t, n.CertType, old.CertType)
assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals)
assert.Equals(t, n.KeyId, old.KeyId)
assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix()))
assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix()))
l8r := now.Add(1 * time.Hour)
assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix()))
assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix()))
},
}
},
}
for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
auth := tc.auth
if auth == nil {
auth = a
}
a.sshCAUserCertSignKey = tc.userSigner
a.sshCAHostCertSignKey = tc.hostSigner
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.cmpResult(tc.cert, cert)
}
}
})
}
}

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=

View file

@ -0,0 +1 @@
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
-----END EC PRIVATE KEY-----

View file

@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
-----END EC PRIVATE KEY-----

View file

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
@ -60,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var ( var (
errContext = apiCtx{"csr": csr, "signOptions": signOpts} opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{} certValidators = []provisioner.CertificateValidator{}
issIdentity = a.intermediateIdentity issIdentity = a.intermediateIdentity
@ -75,54 +76,52 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
certValidators = append(certValidators, k) certValidators = append(certValidators, k)
case provisioner.CertificateRequestValidator: case provisioner.CertificateRequestValidator:
if err := k.Valid(csr); err != nil { if err := k.Valid(csr); err != nil {
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
} }
case provisioner.ProfileModifier: case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts)) mods = append(mods, k.Option(signOpts))
default: default:
return nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k), return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
http.StatusInternalServerError, errContext}
} }
} }
if err := csr.CheckSignature(); err != nil { if err := csr.CheckSignature(); err != nil {
return nil, &apiError{errors.Wrap(err, "sign: invalid certificate request"), return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...)
http.StatusBadRequest, errContext}
} }
leaf, err := x509util.NewLeafProfileWithCSR(csr, issIdentity.Crt, issIdentity.Key, mods...) leaf, err := x509util.NewLeafProfileWithCSR(csr, issIdentity.Crt, issIdentity.Key, mods...)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext} return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign", opts...)
} }
for _, v := range certValidators { for _, v := range certValidators {
if err := v.Valid(leaf.Subject()); err != nil { if err := v.Valid(leaf.Subject(), signOpts); err != nil {
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
} }
} }
crtBytes, err := leaf.CreateCertificate() crtBytes, err := leaf.CreateCertificate()
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.Sign; error creating new leaf certificate", opts...)
} }
serverCert, err := x509.ParseCertificate(crtBytes) serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error parsing new leaf certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.Sign; error parsing new leaf certificate", opts...)
} }
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.Sign; error parsing intermediate certificate", opts...)
} }
if err = a.db.StoreCertificate(serverCert); err != nil { if err = a.db.StoreCertificate(serverCert); err != nil {
if err != db.ErrNotImplemented { if err != db.ErrNotImplemented {
return nil, &apiError{errors.Wrap(err, "sign: error storing certificate in db"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.Sign; error storing certificate in db", opts...)
} }
} }
@ -132,9 +131,11 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
// Renew creates a new Certificate identical to the old certificate, except // Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'. // with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
// Check step provisioner extensions // Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil { if err := a.authorizeRenew(oldCert); err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
} }
// Issuer // Issuer
@ -161,16 +162,16 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
MaxPathLenZero: oldCert.MaxPathLenZero, MaxPathLenZero: oldCert.MaxPathLenZero,
OCSPServer: oldCert.OCSPServer, OCSPServer: oldCert.OCSPServer,
IssuingCertificateURL: oldCert.IssuingCertificateURL, IssuingCertificateURL: oldCert.IssuingCertificateURL,
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
DNSNames: oldCert.DNSNames, DNSNames: oldCert.DNSNames,
EmailAddresses: oldCert.EmailAddresses, EmailAddresses: oldCert.EmailAddresses,
IPAddresses: oldCert.IPAddresses, IPAddresses: oldCert.IPAddresses,
URIs: oldCert.URIs, URIs: oldCert.URIs,
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
PermittedDNSDomains: oldCert.PermittedDNSDomains, PermittedDNSDomains: oldCert.PermittedDNSDomains,
ExcludedDNSDomains: oldCert.ExcludedDNSDomains, ExcludedDNSDomains: oldCert.ExcludedDNSDomains,
PermittedIPRanges: oldCert.PermittedIPRanges, PermittedIPRanges: oldCert.PermittedIPRanges,
ExcludedIPRanges: oldCert.ExcludedIPRanges, ExcludedIPRanges: oldCert.ExcludedIPRanges,
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses, ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses,
PermittedURIDomains: oldCert.PermittedURIDomains, PermittedURIDomains: oldCert.PermittedURIDomains,
ExcludedURIDomains: oldCert.ExcludedURIDomains, ExcludedURIDomains: oldCert.ExcludedURIDomains,
@ -190,29 +191,28 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
leaf, err := x509util.NewLeafProfileWithTemplate(newCert, leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
issIdentity.Crt, issIdentity.Key) issIdentity.Crt, issIdentity.Key)
if err != nil { if err != nil {
return nil, &apiError{err, http.StatusInternalServerError, apiCtx{}} return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
} }
crtBytes, err := leaf.CreateCertificate() crtBytes, err := leaf.CreateCertificate()
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, apiCtx{}} "authority.Renew; error renewing certificate from existing server certificate", opts...)
} }
serverCert, err := x509.ParseCertificate(crtBytes) serverCert, err := x509.ParseCertificate(crtBytes)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "error parsing new server certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, apiCtx{}} "authority.Renew; error parsing new server certificate", opts...)
} }
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
if err != nil { if err != nil {
return nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"), return nil, errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, apiCtx{}} "authority.Renew; error parsing intermediate certificate", opts...)
} }
if err = a.db.StoreCertificate(serverCert); err != nil { if err = a.db.StoreCertificate(serverCert); err != nil {
if err != db.ErrNotImplemented { if err != db.ErrNotImplemented {
return nil, &apiError{errors.Wrap(err, "error storing certificate in db"), return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew; error storing certificate in db", opts...)
http.StatusInternalServerError, apiCtx{}}
} }
} }
@ -236,26 +236,26 @@ type RevokeOptions struct {
// being renewed. // being renewed.
// //
// TODO: Add OCSP and CRL support. // TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error { func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
errContext := apiCtx{ opts := []interface{}{
"serialNumber": opts.Serial, errs.WithKeyVal("serialNumber", revokeOpts.Serial),
"reasonCode": opts.ReasonCode, errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
"reason": opts.Reason, errs.WithKeyVal("reason", revokeOpts.Reason),
"passiveOnly": opts.PassiveOnly, errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
"mTLS": opts.MTLS, errs.WithKeyVal("MTLS", revokeOpts.MTLS),
"context": string(provisioner.MethodFromContext(ctx)), errs.WithKeyVal("context", string(provisioner.MethodFromContext(ctx))),
} }
if opts.MTLS { if revokeOpts.MTLS {
errContext["certificate"] = base64.StdEncoding.EncodeToString(opts.Crt.Raw) opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
} else { } else {
errContext["ott"] = opts.OTT opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
} }
rci := &db.RevokedCertificateInfo{ rci := &db.RevokedCertificateInfo{
Serial: opts.Serial, Serial: revokeOpts.Serial,
ReasonCode: opts.ReasonCode, ReasonCode: revokeOpts.ReasonCode,
Reason: opts.Reason, Reason: revokeOpts.Reason,
MTLS: opts.MTLS, MTLS: revokeOpts.MTLS,
RevokedAt: time.Now().UTC(), RevokedAt: time.Now().UTC(),
} }
@ -264,48 +264,43 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
err error err error
) )
// If not mTLS then get the TokenID of the token. // If not mTLS then get the TokenID of the token.
if !opts.MTLS { if !revokeOpts.MTLS {
// Validate payload token, err := jose.ParseSigned(revokeOpts.OTT)
token, err := jose.ParseSigned(opts.OTT)
if err != nil { if err != nil {
return &apiError{errors.Wrapf(err, "revoke: error parsing token"), return errs.Wrap(http.StatusUnauthorized, err,
http.StatusUnauthorized, errContext} "authority.Revoke; error parsing token", opts...)
} }
// Get claims w/out verification. We should have already verified this token // Get claims w/out verification.
// earlier with a call to authorizeSSHRevoke.
var claims Claims var claims Claims
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return &apiError{errors.Wrap(err, "revoke"), http.StatusUnauthorized, errContext} return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...)
} }
// This method will also validate the audiences for JWK provisioners. // This method will also validate the audiences for JWK provisioners.
var ok bool var ok bool
p, ok = a.provisioners.LoadByToken(token, &claims.Claims) p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
if !ok { if !ok {
return &apiError{ return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
errors.Errorf("revoke: provisioner not found"),
http.StatusInternalServerError, errContext}
} }
rci.TokenID, err = p.GetTokenID(opts.OTT) rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
if err != nil { if err != nil {
return &apiError{errors.Wrap(err, "revoke: could not get ID for token"), return errs.Wrap(http.StatusInternalServerError, err,
http.StatusInternalServerError, errContext} "authority.Revoke; could not get ID for token")
} }
errContext["tokenID"] = rci.TokenID opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID))
} else { } else {
// Load the Certificate provisioner if one exists. // Load the Certificate provisioner if one exists.
p, err = a.LoadProvisionerByCertificate(opts.Crt) p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt)
if err != nil { if err != nil {
return &apiError{ return errs.Wrap(http.StatusUnauthorized, err,
errors.Wrap(err, "revoke: unable to load certificate provisioner"), "authority.Revoke: unable to load certificate provisioner", opts...)
http.StatusUnauthorized, errContext}
} }
} }
rci.ProvisionerID = p.GetID() rci.ProvisionerID = p.GetID()
errContext["provisionerID"] = rci.ProvisionerID opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
if provisioner.MethodFromContext(ctx) == provisioner.RevokeSSHMethod { if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
err = a.db.RevokeSSH(rci) err = a.db.RevokeSSH(rci)
} else { // default to revoke x509 } else { // default to revoke x509
err = a.db.Revoke(rci) err = a.db.Revoke(rci)
@ -314,13 +309,12 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
case nil: case nil:
return nil return nil
case db.ErrNotImplemented: case db.ErrNotImplemented:
return &apiError{errors.New("revoke: no persistence layer configured"), return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
http.StatusNotImplemented, errContext}
case db.ErrAlreadyExists: case db.ErrAlreadyExists:
return &apiError{errors.Errorf("revoke: certificate with serial number %s has already been revoked", rci.Serial), return errs.BadRequest("authority.Revoke; certificate with serial "+
http.StatusBadRequest, errContext} "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
default: default:
return &apiError{err, http.StatusInternalServerError, errContext} return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
} }
} }
@ -330,17 +324,17 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
a.intermediateIdentity.Crt, a.intermediateIdentity.Key, a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
x509util.WithHosts(strings.Join(a.config.DNSNames, ","))) x509util.WithHosts(strings.Join(a.config.DNSNames, ",")))
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
} }
crtBytes, err := profile.CreateCertificate() crtBytes, err := profile.CreateCertificate()
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
} }
keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey()) keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey())
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
} }
crtPEM := pem.EncodeToMemory(&pem.Block{ crtPEM := pem.EncodeToMemory(&pem.Block{
@ -352,19 +346,21 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
// to a tls.Certificate. // to a tls.Certificate.
intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt) intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
} }
tlsCrt, err := tls.X509KeyPair(append(crtPEM, tlsCrt, err := tls.X509KeyPair(append(crtPEM,
pem.EncodeToMemory(intermediatePEM)...), pem.EncodeToMemory(intermediatePEM)...),
pem.EncodeToMemory(keyPEM)) pem.EncodeToMemory(keyPEM))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error creating tls certificate") return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.GetTLSCertificate; error creating tls certificate")
} }
// Get the 'leaf' certificate and set the attribute accordingly. // Get the 'leaf' certificate and set the attribute accordingly.
leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0]) leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0])
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing tls certificate") return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.GetTLSCertificate; error parsing tls certificate")
} }
tlsCrt.Leaf = leaf tlsCrt.Leaf = leaf

View file

@ -7,7 +7,6 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/base64"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net/http" "net/http"
@ -19,6 +18,7 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/tlsutil"
@ -77,7 +77,7 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques
return csr return csr
} }
func TestSign(t *testing.T) { func TestAuthority_Sign(t *testing.T) {
pub, priv, err := keys.GenerateDefaultKeyPair() pub, priv, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -102,7 +102,7 @@ func TestSign(t *testing.T) {
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key) token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token) extraOpts, err := a.Authorize(ctx, token)
@ -113,7 +113,8 @@ func TestSign(t *testing.T) {
csr *x509.CertificateRequest csr *x509.CertificateRequest
signOpts provisioner.Options signOpts provisioner.Options
extraOpts []provisioner.SignOption extraOpts []provisioner.SignOption
err *apiError err error
code int
} }
tests := map[string]func(*testing.T) *signTest{ tests := map[string]func(*testing.T) *signTest{
"fail invalid signature": func(t *testing.T) *signTest { "fail invalid signature": func(t *testing.T) *signTest {
@ -124,10 +125,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: invalid certificate request"), err: errors.New("authority.Sign; invalid certificate request"),
http.StatusBadRequest, code: http.StatusBadRequest,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail invalid extra option": func(t *testing.T) *signTest { "fail invalid extra option": func(t *testing.T) *signTest {
@ -138,10 +137,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: append(extraOpts, "42"), extraOpts: append(extraOpts, "42"),
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: invalid extra option type string"), err: errors.New("authority.Sign; invalid extra option type string"),
http.StatusInternalServerError, code: http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail merge default ASN1DN": func(t *testing.T) *signTest { "fail merge default ASN1DN": func(t *testing.T) *signTest {
@ -153,10 +150,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"), err: errors.New("authority.Sign: default ASN1DN template cannot be nil"),
http.StatusInternalServerError, code: http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail create cert": func(t *testing.T) *signTest { "fail create cert": func(t *testing.T) *signTest {
@ -168,10 +163,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: error creating new leaf certificate"), err: errors.New("authority.Sign; error creating new leaf certificate"),
http.StatusInternalServerError, code: http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail provisioner duration claim": func(t *testing.T) *signTest { "fail provisioner duration claim": func(t *testing.T) *signTest {
@ -185,10 +178,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
http.StatusUnauthorized, code: http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": _signOpts},
},
} }
}, },
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -200,10 +191,8 @@ func TestSign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
http.StatusUnauthorized, code: http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail rsa key too short": func(t *testing.T) *signTest { "fail rsa key too short": func(t *testing.T) *signTest {
@ -228,20 +217,16 @@ ZYtQ9Ot36qc=
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
http.StatusUnauthorized, code: http.StatusUnauthorized,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"fail store cert in db": func(t *testing.T) *signTest { "fail store cert in db": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
_a := testAuthority(t) _a := testAuthority(t)
_a.db = &MockAuthDB{ _a.db = &db.MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
return &apiError{errors.New("force"), return errors.New("force")
http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts}}
}, },
} }
return &signTest{ return &signTest{
@ -249,17 +234,15 @@ ZYtQ9Ot36qc=
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: &apiError{errors.New("sign: error storing certificate in db: force"), err: errors.New("authority.Sign; error storing certificate in db: force"),
http.StatusInternalServerError, code: http.StatusInternalServerError,
apiCtx{"csr": csr, "signOptions": signOpts},
},
} }
}, },
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
_a := testAuthority(t) _a := testAuthority(t)
_a.db = &MockAuthDB{ _a.db = &db.MockAuthDB{
storeCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
@ -279,15 +262,17 @@ ZYtQ9Ot36qc=
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
switch v := err.(type) { assert.Nil(t, certChain)
case *apiError: sc, ok := err.(errs.StatusCoder)
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, v.code, tc.err.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.context, tc.err.context) assert.HasPrefix(t, err.Error(), tc.err.Error())
default:
t.Errorf("unexpected error type: %T", v) ctxErr, ok := err.(*errs.Error)
} assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["csr"], tc.csr)
assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts)
} }
} else { } else {
leaf := certChain[0] leaf := certChain[0]
@ -346,7 +331,7 @@ ZYtQ9Ot36qc=
} }
} }
func TestRenew(t *testing.T) { func TestAuthority_Renew(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair() pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -375,9 +360,9 @@ func TestRenew(t *testing.T) {
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID)) withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
assert.FatalError(t, err) assert.FatalError(t, err)
crtBytes, err := leaf.CreateCertificate() certBytes, err := leaf.CreateCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := x509.ParseCertificate(crtBytes) cert, err := x509.ParseCertificate(certBytes)
assert.FatalError(t, err) assert.FatalError(t, err)
leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt, leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt,
@ -388,15 +373,16 @@ func TestRenew(t *testing.T) {
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
) )
assert.FatalError(t, err) assert.FatalError(t, err)
crtBytesNoRenew, err := leafNoRenew.CreateCertificate() certBytesNoRenew, err := leafNoRenew.CreateCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
crtNoRenew, err := x509.ParseCertificate(crtBytesNoRenew) certNoRenew, err := x509.ParseCertificate(certBytesNoRenew)
assert.FatalError(t, err) assert.FatalError(t, err)
type renewTest struct { type renewTest struct {
auth *Authority auth *Authority
crt *x509.Certificate cert *x509.Certificate
err *apiError err error
code int
} }
tests := map[string]func() (*renewTest, error){ tests := map[string]func() (*renewTest, error){
"fail-create-cert": func() (*renewTest, error) { "fail-create-cert": func() (*renewTest, error) {
@ -404,25 +390,22 @@ func TestRenew(t *testing.T) {
_a.intermediateIdentity.Key = nil _a.intermediateIdentity.Key = nil
return &renewTest{ return &renewTest{
auth: _a, auth: _a,
crt: crt, cert: cert,
err: &apiError{errors.New("error renewing certificate from existing server certificate"), err: errors.New("authority.Renew; error renewing certificate from existing server certificate"),
http.StatusInternalServerError, apiCtx{}}, code: http.StatusInternalServerError,
}, nil }, nil
}, },
"fail-unauthorized": func() (*renewTest, error) { "fail-unauthorized": func() (*renewTest, error) {
ctx := map[string]interface{}{
"serialNumber": crtNoRenew.SerialNumber.String(),
}
return &renewTest{ return &renewTest{
crt: crtNoRenew, cert: certNoRenew,
err: &apiError{errors.New("renew: renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), err: errors.New("authority.Renew: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
http.StatusUnauthorized, ctx}, code: http.StatusUnauthorized,
}, nil }, nil
}, },
"success": func() (*renewTest, error) { "success": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
auth: a, auth: a,
crt: crt, cert: cert,
}, nil }, nil
}, },
"success-new-intermediate": func() (*renewTest, error) { "success-new-intermediate": func() (*renewTest, error) {
@ -430,23 +413,23 @@ func TestRenew(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
newRootBytes, err := newRootProfile.CreateCertificate() newRootBytes, err := newRootProfile.CreateCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
newRootCrt, err := x509.ParseCertificate(newRootBytes) newRootCert, err := x509.ParseCertificate(newRootBytes)
assert.FatalError(t, err) assert.FatalError(t, err)
newIntermediateProfile, err := x509util.NewIntermediateProfile("new-intermediate", newIntermediateProfile, err := x509util.NewIntermediateProfile("new-intermediate",
newRootCrt, newRootProfile.SubjectPrivateKey()) newRootCert, newRootProfile.SubjectPrivateKey())
assert.FatalError(t, err) assert.FatalError(t, err)
newIntermediateBytes, err := newIntermediateProfile.CreateCertificate() newIntermediateBytes, err := newIntermediateProfile.CreateCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
newIntermediateCrt, err := x509.ParseCertificate(newIntermediateBytes) newIntermediateCert, err := x509.ParseCertificate(newIntermediateBytes)
assert.FatalError(t, err) assert.FatalError(t, err)
_a := testAuthority(t) _a := testAuthority(t)
_a.intermediateIdentity.Key = newIntermediateProfile.SubjectPrivateKey() _a.intermediateIdentity.Key = newIntermediateProfile.SubjectPrivateKey()
_a.intermediateIdentity.Crt = newIntermediateCrt _a.intermediateIdentity.Crt = newIntermediateCert
return &renewTest{ return &renewTest{
auth: _a, auth: _a,
crt: crt, cert: cert,
}, nil }, nil
}, },
} }
@ -458,32 +441,33 @@ func TestRenew(t *testing.T) {
var certChain []*x509.Certificate var certChain []*x509.Certificate
if tc.auth != nil { if tc.auth != nil {
certChain, err = tc.auth.Renew(tc.crt) certChain, err = tc.auth.Renew(tc.cert)
} else { } else {
certChain, err = a.Renew(tc.crt) certChain, err = a.Renew(tc.cert)
} }
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
switch v := err.(type) { assert.Nil(t, certChain)
case *apiError: sc, ok := err.(errs.StatusCoder)
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, v.code, tc.err.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.context, tc.err.context) assert.HasPrefix(t, err.Error(), tc.err.Error())
default:
t.Errorf("unexpected error type: %T", v) ctxErr, ok := err.(*errs.Error)
} assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
} }
} else { } else {
leaf := certChain[0] leaf := certChain[0]
intermediate := certChain[1] intermediate := certChain[1]
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.crt.NotAfter.Sub(crt.NotBefore)) assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore))
assert.True(t, leaf.NotBefore.After(now.Add(-time.Minute))) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute)))
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
@ -513,7 +497,7 @@ func TestRenew(t *testing.T) {
if a.intermediateIdentity.Crt.SerialNumber == tc.auth.intermediateIdentity.Crt.SerialNumber { if a.intermediateIdentity.Crt.SerialNumber == tc.auth.intermediateIdentity.Crt.SerialNumber {
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId) assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.crt.Extensions { for _, ext1 := range tc.cert.Extensions {
found := false found := false
for _, ext2 := range leaf.Extensions { for _, ext2 := range leaf.Extensions {
if reflect.DeepEqual(ext1, ext2) { if reflect.DeepEqual(ext1, ext2) {
@ -529,7 +513,7 @@ func TestRenew(t *testing.T) {
// We did change the intermediate before renewing. // We did change the intermediate before renewing.
assert.Equals(t, leaf.AuthorityKeyId, tc.auth.intermediateIdentity.Crt.SubjectKeyId) assert.Equals(t, leaf.AuthorityKeyId, tc.auth.intermediateIdentity.Crt.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.crt.Extensions { for _, ext1 := range tc.cert.Extensions {
// The authority key id extension should be different b/c the intermediates are different. // The authority key id extension should be different b/c the intermediates are different.
if ext1.Id.Equal(oidAuthorityKeyIdentifier) { if ext1.Id.Equal(oidAuthorityKeyIdentifier) {
for _, ext2 := range leaf.Extensions { for _, ext2 := range leaf.Extensions {
@ -560,7 +544,7 @@ func TestRenew(t *testing.T) {
} }
} }
func TestGetTLSOptions(t *testing.T) { func TestAuthority_GetTLSOptions(t *testing.T) {
type renewTest struct { type renewTest struct {
auth *Authority auth *Authority
opts *tlsutil.TLSOptions opts *tlsutil.TLSOptions
@ -596,21 +580,12 @@ func TestGetTLSOptions(t *testing.T) {
} }
} }
func TestRevoke(t *testing.T) { func TestAuthority_Revoke(t *testing.T) {
reasonCode := 2 reasonCode := 2
reason := "bob was let go" reason := "bob was let go"
validIssuer := "step-cli" validIssuer := "step-cli"
validAudience := []string{"https://test.ca.smallstep.com/revoke"} validAudience := testAudiences.Revoke
now := time.Now().UTC() now := time.Now().UTC()
getCtx := func() map[string]interface{} {
return apiCtx{
"serialNumber": "sn",
"reasonCode": reasonCode,
"reason": reason,
"mTLS": false,
"passiveOnly": false,
}
}
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -619,30 +594,30 @@ func TestRevoke(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) assert.FatalError(t, err)
a := testAuthority(t)
type test struct { type test struct {
a *Authority auth *Authority
opts *RevokeOptions opts *RevokeOptions
err *apiError err error
code int
checkErrDetails func(err *errs.Error)
} }
tests := map[string]func() test{ tests := map[string]func() test{
"error/token/authorizeRevoke error": func() test { "fail/token/authorizeRevoke error": func() test {
a := testAuthority(t)
ctx := getCtx()
ctx["ott"] = "foo"
return test{ return test{
a: a, auth: a,
opts: &RevokeOptions{ opts: &RevokeOptions{
OTT: "foo", OTT: "foo",
Serial: "sn", Serial: "sn",
ReasonCode: reasonCode, ReasonCode: reasonCode,
Reason: reason, Reason: reason,
}, },
err: &apiError{errors.New("revoke: authorizeRevoke: authorizeToken: error parsing token"), err: errors.New("authority.Revoke; error parsing token"),
http.StatusUnauthorized, ctx}, code: http.StatusUnauthorized,
} }
}, },
"error/nil-db": func() test { "fail/nil-db": func() test {
a := testAuthority(t)
cl := jwt.Claims{ cl := jwt.Claims{
Subject: "sn", Subject: "sn",
Issuer: validIssuer, Issuer: validIssuer,
@ -654,30 +629,30 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{ return test{
a: a, auth: a,
opts: &RevokeOptions{ opts: &RevokeOptions{
Serial: "sn", Serial: "sn",
ReasonCode: reasonCode, ReasonCode: reasonCode,
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: &apiError{errors.New("revoke: no persistence layer configured"), err: errors.New("authority.Revoke; no persistence layer configured"),
http.StatusNotImplemented, ctx}, code: http.StatusNotImplemented,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
} }
}, },
"error/db-revoke": func() test { "fail/db-revoke": func() test {
a := testAuthority(t) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{
a.db = &MockAuthDB{ MUseToken: func(id, tok string) (bool, error) {
useToken: func(id, tok string) (bool, error) {
return true, nil return true, nil
}, },
err: errors.New("force"), Err: errors.New("force"),
} }))
cl := jwt.Claims{ cl := jwt.Claims{
Subject: "sn", Subject: "sn",
@ -690,30 +665,30 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{ return test{
a: a, auth: _a,
opts: &RevokeOptions{ opts: &RevokeOptions{
Serial: "sn", Serial: "sn",
ReasonCode: reasonCode, ReasonCode: reasonCode,
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: &apiError{errors.New("force"), err: errors.New("authority.Revoke: force"),
http.StatusInternalServerError, ctx}, code: http.StatusInternalServerError,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
} }
}, },
"error/already-revoked": func() test { "fail/already-revoked": func() test {
a := testAuthority(t) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{
a.db = &MockAuthDB{ MUseToken: func(id, tok string) (bool, error) {
useToken: func(id, tok string) (bool, error) {
return true, nil return true, nil
}, },
err: db.ErrAlreadyExists, Err: db.ErrAlreadyExists,
} }))
cl := jwt.Claims{ cl := jwt.Claims{
Subject: "sn", Subject: "sn",
@ -726,29 +701,29 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := getCtx()
ctx["ott"] = raw
ctx["tokenID"] = "44"
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
return test{ return test{
a: a, auth: _a,
opts: &RevokeOptions{ opts: &RevokeOptions{
Serial: "sn", Serial: "sn",
ReasonCode: reasonCode, ReasonCode: reasonCode,
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: &apiError{errors.New("revoke: certificate with serial number sn has already been revoked"), err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"),
http.StatusBadRequest, ctx}, code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
},
} }
}, },
"ok/token": func() test { "ok/token": func() test {
a := testAuthority(t) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{
a.db = &MockAuthDB{ MUseToken: func(id, tok string) (bool, error) {
useToken: func(id, tok string) (bool, error) {
return true, nil return true, nil
}, },
} }))
cl := jwt.Claims{ cl := jwt.Claims{
Subject: "sn", Subject: "sn",
@ -761,7 +736,7 @@ func TestRevoke(t *testing.T) {
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
a: a, auth: _a,
opts: &RevokeOptions{ opts: &RevokeOptions{
Serial: "sn", Serial: "sn",
ReasonCode: reasonCode, ReasonCode: reasonCode,
@ -770,39 +745,14 @@ func TestRevoke(t *testing.T) {
}, },
} }
}, },
"error/mTLS/authorizeRevoke": func() test {
a := testAuthority(t)
a.db = &MockAuthDB{}
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
ctx := getCtx()
ctx["certificate"] = base64.StdEncoding.EncodeToString(crt.Raw)
ctx["mTLS"] = true
return test{
a: a,
opts: &RevokeOptions{
Crt: crt,
Serial: "sn",
ReasonCode: reasonCode,
Reason: reason,
MTLS: true,
},
err: &apiError{errors.New("revoke: authorizeRevoke: serial number in certificate different than body"),
http.StatusUnauthorized, ctx},
}
},
"ok/mTLS": func() test { "ok/mTLS": func() test {
a := testAuthority(t) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
a.db = &MockAuthDB{}
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
a: a, auth: _a,
opts: &RevokeOptions{ opts: &RevokeOptions{
Crt: crt, Crt: crt,
Serial: "102012593071130646873265215610956555026", Serial: "102012593071130646873265215610956555026",
@ -816,15 +766,24 @@ func TestRevoke(t *testing.T) {
for name, f := range tests { for name, f := range tests {
tc := f() tc := f()
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
if err := tc.a.Revoke(context.TODO(), tc.opts); err != nil { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
if assert.NotNil(t, tc.err) { if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
switch v := err.(type) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
case *apiError: sc, ok := err.(errs.StatusCoder)
assert.HasPrefix(t, v.err.Error(), tc.err.Error()) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, v.code, tc.err.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.Equals(t, v.context, tc.err.context) assert.HasPrefix(t, err.Error(), tc.err.Error())
default:
t.Errorf("unexpected error type: %T", v) ctxErr, ok := err.(*errs.Error)
assert.Fatal(t, ok, "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial)
assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode)
assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason)
assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS)
assert.Equals(t, ctxErr.Details["context"], string(provisioner.RevokeMethod))
if tc.checkErrDetails != nil {
tc.checkErrDetails(ctxErr)
} }
} }
} else { } else {

View file

@ -22,6 +22,7 @@ import (
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
@ -102,7 +103,7 @@ func TestCASign(t *testing.T) {
ca: ca, ca: ca,
body: "invalid json", body: "invalid json",
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: "Bad Request", errMsg: errs.BadRequestDefaultMsg,
} }
}, },
"fail invalid-csr-sig": func(t *testing.T) *signTest { "fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -140,7 +141,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: "Bad Request", errMsg: errs.BadRequestDefaultMsg,
} }
}, },
"fail unauthorized-ott": func(t *testing.T) *signTest { "fail unauthorized-ott": func(t *testing.T) *signTest {
@ -155,7 +156,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusUnauthorized, status: http.StatusUnauthorized,
errMsg: "Unauthorized", errMsg: errs.UnauthorizedDefaultMsg,
} }
}, },
"fail commonname-claim": func(t *testing.T) *signTest { "fail commonname-claim": func(t *testing.T) *signTest {
@ -188,7 +189,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusUnauthorized, status: http.StatusUnauthorized,
errMsg: "Unauthorized", errMsg: errs.UnauthorizedDefaultMsg,
} }
}, },
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
@ -392,7 +393,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
ca: ca, ca: ca,
kid: "foo", kid: "foo",
status: http.StatusNotFound, status: http.StatusNotFound,
errMsg: "Not Found", errMsg: errs.NotFoundDefaultMsg,
} }
}, },
"ok": func(t *testing.T) *ekt { "ok": func(t *testing.T) *ekt {
@ -455,7 +456,7 @@ func TestCARoot(t *testing.T) {
ca: ca, ca: ca,
sha: "foo", sha: "foo",
status: http.StatusNotFound, status: http.StatusNotFound,
errMsg: "Not Found", errMsg: errs.NotFoundDefaultMsg,
} }
}, },
"success": func(t *testing.T) *rootTest { "success": func(t *testing.T) *rootTest {
@ -575,7 +576,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: nil, tlsConnState: nil,
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: "Bad Request", errMsg: errs.BadRequestDefaultMsg,
} }
}, },
"request-missing-peer-certificate": func(t *testing.T) *renewTest { "request-missing-peer-certificate": func(t *testing.T) *renewTest {
@ -583,7 +584,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: "Bad Request", errMsg: errs.BadRequestDefaultMsg,
} }
}, },
"success": func(t *testing.T) *renewTest { "success": func(t *testing.T) *renewTest {

View file

@ -486,7 +486,7 @@ func (c *Client) Version() (*api.VersionResponse, error) {
retry: retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -497,7 +497,7 @@ retry:
} }
var version api.VersionResponse var version api.VersionResponse
if err := readJSON(resp.Body, &version); err != nil { if err := readJSON(resp.Body, &version); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", u)
} }
return &version, nil return &version, nil
} }
@ -510,7 +510,7 @@ func (c *Client) Health() (*api.HealthResponse, error) {
retry: retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -521,7 +521,7 @@ retry:
} }
var health api.HealthResponse var health api.HealthResponse
if err := readJSON(resp.Body, &health); err != nil { if err := readJSON(resp.Body, &health); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", u)
} }
return &health, nil return &health, nil
} }
@ -537,7 +537,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
retry: retry:
resp, err := newInsecureClient().Get(u.String()) resp, err := newInsecureClient().Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -548,12 +548,12 @@ retry:
} }
var root api.RootResponse var root api.RootResponse
if err := readJSON(resp.Body, &root); err != nil { if err := readJSON(resp.Body, &root); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", u)
} }
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
return nil, errors.New("root certificate SHA256 fingerprint do not match") return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
} }
return &root, nil return &root, nil
} }
@ -564,13 +564,13 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
var retried bool var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
retry: retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -581,7 +581,7 @@ retry:
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", u)
} }
// Add tls.ConnectionState: // Add tls.ConnectionState:
// We'll extract the root certificate from the verified chains // We'll extract the root certificate from the verified chains
@ -598,7 +598,7 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
retry: retry:
resp, err := client.Post(u.String(), "application/json", http.NoBody) resp, err := client.Post(u.String(), "application/json", http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -609,7 +609,7 @@ retry:
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", u)
} }
return &sign, nil return &sign, nil
} }
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
retry: retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u, return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
errs.WithMessage("Failed to perform POST request to %s", u)) []interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -974,8 +974,8 @@ retry:
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u, return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")) []interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
} }
return &check, nil return &check, nil
} }
@ -1008,13 +1008,13 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse
var retried bool var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
retry: retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client.SSHBastion; client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) { if !retried && c.retryOnError(resp) {
@ -1025,7 +1025,7 @@ retry:
} }
var bastion api.SSHBastionResponse var bastion api.SSHBastionResponse
if err := readJSON(resp.Body, &bastion); err != nil { if err := readJSON(resp.Body, &bastion); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u)
} }
return &bastion, nil return &bastion, nil
} }

View file

@ -16,12 +16,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/errs" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -154,18 +154,17 @@ func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
func TestClient_Version(t *testing.T) { func TestClient_Version(t *testing.T) {
ok := &api.VersionResponse{Version: "test"} ok := &api.VersionResponse{Version: "test"}
internal := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
expectedErr error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"500", internal, 500, true}, {"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", notFound, 404, true}, {"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -185,7 +184,6 @@ func TestClient_Version(t *testing.T) {
got, err := c.Version() got, err := c.Version()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
@ -195,9 +193,7 @@ func TestClient_Version(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Version() = %v, want nil", got) t.Errorf("Client.Version() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
t.Errorf("Client.Version() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response) t.Errorf("Client.Version() = %v, want %v", got, tt.response)
@ -209,16 +205,16 @@ func TestClient_Version(t *testing.T) {
func TestClient_Health(t *testing.T) { func TestClient_Health(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"} ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
expectedErr error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"not ok", nok, 500, true}, {"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -248,9 +244,7 @@ func TestClient_Health(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Health() = %v, want nil", got) t.Errorf("Client.Health() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
t.Errorf("Client.Health() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response) t.Errorf("Client.Health() = %v, want %v", got, tt.response)
@ -264,7 +258,6 @@ func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{ ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
} }
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct { tests := []struct {
name string name string
@ -272,9 +265,10 @@ func TestClient_Root(t *testing.T) {
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
expectedErr error
}{ }{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false}, {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
{"not found", "invalid", notFound, 404, true}, {"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -307,9 +301,7 @@ func TestClient_Root(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Root() = %v, want nil", got) t.Errorf("Client.Root() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
t.Errorf("Client.Root() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response) t.Errorf("Client.Root() = %v, want %v", got, tt.response)
@ -334,8 +326,6 @@ func TestClient_Sign(t *testing.T) {
NotBefore: api.NewTimeDuration(time.Now()), NotBefore: api.NewTimeDuration(time.Now()),
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
} }
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
@ -343,11 +333,12 @@ func TestClient_Sign(t *testing.T) {
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
expectedErr error
}{ }{
{"ok", request, ok, 200, false}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, unauthorized, 401, true}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, badRequest, 403, true}, {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, badRequest, 403, true}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -364,7 +355,9 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -390,9 +383,7 @@ func TestClient_Sign(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got) t.Errorf("Client.Sign() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
t.Errorf("Client.Sign() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response) t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
@ -409,19 +400,17 @@ func TestClient_Revoke(t *testing.T) {
OTT: "the-ott", OTT: "the-ott",
ReasonCode: 4, ReasonCode: 4,
} }
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
request *api.RevokeRequest request *api.RevokeRequest
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
expectedErr error
}{ }{
{"ok", request, ok, 200, false}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, unauthorized, 401, true}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, badRequest, 403, true}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -438,7 +427,9 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -464,9 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got) t.Errorf("Client.Revoke() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -485,19 +474,18 @@ func TestClient_Renew(t *testing.T) {
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(rootPEM)},
}, },
} }
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"unauthorized", unauthorized, 401, true}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", badRequest, 403, true}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", badRequest, 403, true}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -527,9 +515,11 @@ func TestClient_Renew(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got) t.Errorf("Client.Renew() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Renew() error = %v, want %v", err, tt.response) sc, ok := err.(errs.StatusCoder)
} assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -543,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{ ok := &api.ProvisionersResponse{
Provisioners: provisioner.List{}, Provisioners: provisioner.List{},
} }
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error")) internalServerError := errs.InternalServer("Internal Server Error")
tests := []struct { tests := []struct {
name string name string
@ -589,9 +579,7 @@ func TestClient_Provisioners(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got) t.Errorf("Client.Provisioners() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response)
}
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
@ -605,7 +593,6 @@ func TestClient_ProvisionerKey(t *testing.T) {
ok := &api.ProvisionerKeyResponse{ ok := &api.ProvisionerKeyResponse{
Key: "an encrypted key", Key: "an encrypted key",
} }
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct { tests := []struct {
name string name string
@ -613,9 +600,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", "kid", ok, 200, false}, {"ok", "kid", ok, 200, false, nil},
{"fail", "invalid", notFound, 500, true}, {"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -648,9 +636,11 @@ func TestClient_ProvisionerKey(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got) t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response) sc, ok := err.(errs.StatusCoder)
} assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
@ -666,19 +656,17 @@ func TestClient_Roots(t *testing.T) {
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(rootPEM)},
}, },
} }
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"unauthorized", unauthorized, 401, true}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", badRequest, 403, true}, {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", badRequest, 403, true},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -708,9 +696,10 @@ func TestClient_Roots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got) t.Errorf("Client.Roots() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { sc, ok := err.(errs.StatusCoder)
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response) t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -726,19 +715,16 @@ func TestClient_Federation(t *testing.T) {
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(rootPEM)},
}, },
} }
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"unauthorized", unauthorized, 401, true}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -768,9 +754,10 @@ func TestClient_Federation(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got) t.Errorf("Client.Federation() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { sc, ok := err.(errs.StatusCoder)
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response) t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
@ -790,16 +777,16 @@ func TestClient_SSHRoots(t *testing.T) {
HostKeys: []api.SSHPublicKey{{PublicKey: key}}, HostKeys: []api.SSHPublicKey{{PublicKey: key}},
UserKeys: []api.SSHPublicKey{{PublicKey: key}}, UserKeys: []api.SSHPublicKey{{PublicKey: key}},
} }
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct { tests := []struct {
name string name string
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", ok, 200, false}, {"ok", ok, 200, false, nil},
{"not found", notFound, 404, true}, {"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -829,9 +816,10 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got) t.Errorf("Client.SSHKeys() = %v, want nil", got)
} }
if !reflect.DeepEqual(err, tt.response) { sc, ok := err.(errs.StatusCoder)
t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response) assert.Fatal(t, ok, "error does not implement StatusCoder interface")
} assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
@ -881,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
func TestClient_RootFingerprint(t *testing.T) { func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"} ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error")) nok := errs.InternalServer("Internal Server Error")
httpsServer := httptest.NewTLSServer(nil) httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close() defer httpsServer.Close()
@ -948,7 +936,6 @@ func TestClient_SSHBastion(t *testing.T) {
Hostname: "bastion.local", Hostname: "bastion.local",
}, },
} }
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct { tests := []struct {
name string name string
@ -956,11 +943,11 @@ func TestClient_SSHBastion(t *testing.T) {
response interface{} response interface{}
responseCode int responseCode int
wantErr bool wantErr bool
err error
}{ }{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false}, {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"empty request", &api.SSHBastionRequest{}, badRequest, 403, true}, {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, badRequest, 403, true},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -990,8 +977,11 @@ func TestClient_SSHBastion(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.SSHBastion() = %v, want nil", got) t.Errorf("Client.SSHBastion() = %v, want nil", got)
} }
if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) { if tt.responseCode != 200 {
t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
} }
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {

View file

@ -276,6 +276,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity"
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)

View file

@ -40,7 +40,7 @@ func (c *mutableTLSConfig) Init(base *tls.Config) {
// tls.Config GetConfigForClient. // tls.Config GetConfigForClient.
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) { func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
c.RLock() c.RLock()
config = c.config config = c.config.Clone()
c.RUnlock() c.RUnlock()
return return
} }

View file

@ -80,7 +80,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
func (r *TLSRenewer) Run() { func (r *TLSRenewer) Run() {
cert := r.getCertificate() cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter) next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.Lock()
r.timer = time.AfterFunc(next, r.renewCertificate) r.timer = time.AfterFunc(next, r.renewCertificate)
r.Unlock()
} }
// RunContext starts the certificate renewer for the given certificate. // RunContext starts the certificate renewer for the given certificate.

View file

@ -270,6 +270,105 @@ func (db *DB) Shutdown() error {
return nil return nil
} }
// MockAuthDB mocks the AuthDB interface. //
type MockAuthDB struct {
Err error
Ret1 interface{}
MIsRevoked func(string) (bool, error)
MIsSSHRevoked func(string) (bool, error)
MRevoke func(rci *RevokedCertificateInfo) error
MRevokeSSH func(rci *RevokedCertificateInfo) error
MStoreCertificate func(crt *x509.Certificate) error
MUseToken func(id, tok string) (bool, error)
MIsSSHHost func(principal string) (bool, error)
MStoreSSHCertificate func(crt *ssh.Certificate) error
MGetSSHHostPrincipals func() ([]string, error)
MShutdown func() error
}
// IsRevoked mock.
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
if m.MIsRevoked != nil {
return m.MIsRevoked(sn)
}
return m.Ret1.(bool), m.Err
}
// IsSSHRevoked mock.
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
if m.MIsSSHRevoked != nil {
return m.MIsSSHRevoked(sn)
}
return m.Ret1.(bool), m.Err
}
// UseToken mock.
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
if m.MUseToken != nil {
return m.MUseToken(id, tok)
}
if m.Ret1 == nil {
return false, m.Err
}
return m.Ret1.(bool), m.Err
}
// Revoke mock.
func (m *MockAuthDB) Revoke(rci *RevokedCertificateInfo) error {
if m.MRevoke != nil {
return m.MRevoke(rci)
}
return m.Err
}
// RevokeSSH mock.
func (m *MockAuthDB) RevokeSSH(rci *RevokedCertificateInfo) error {
if m.MRevokeSSH != nil {
return m.MRevokeSSH(rci)
}
return m.Err
}
// StoreCertificate mock.
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.MStoreCertificate != nil {
return m.MStoreCertificate(crt)
}
return m.Err
}
// IsSSHHost mock.
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
if m.MIsSSHHost != nil {
return m.MIsSSHHost(principal)
}
return m.Ret1.(bool), m.Err
}
// StoreSSHCertificate mock.
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
if m.MStoreSSHCertificate != nil {
return m.MStoreSSHCertificate(crt)
}
return m.Err
}
// GetSSHHostPrincipals mock.
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
if m.MGetSSHHostPrincipals != nil {
return m.MGetSSHHostPrincipals()
}
return m.Ret1.([]string), m.Err
}
// Shutdown mock.
func (m *MockAuthDB) Shutdown() error {
if m.MShutdown != nil {
return m.MShutdown()
}
return m.Err
}
// MockNoSQLDB // // MockNoSQLDB //
type MockNoSQLDB struct { type MockNoSQLDB struct {
Err error Err error

View file

@ -21,9 +21,9 @@ type StackTracer interface {
// Option modifies the Error type. // Option modifies the Error type.
type Option func(e *Error) error type Option func(e *Error) error
// WithMessage returns an Option that modifies the error by overwriting the // withDefaultMessage returns an Option that modifies the error by overwriting the
// message only if it is empty. // message only if it is empty.
func WithMessage(format string, args ...interface{}) Option { func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error { return func(e *Error) error {
if len(e.Msg) > 0 { if len(e.Msg) > 0 {
return e return e
@ -33,31 +33,33 @@ func WithMessage(format string, args ...interface{}) Option {
} }
} }
// WithMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
func WithMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
e.Msg = fmt.Sprintf(format, args...)
return e
}
}
// WithKeyVal returns an Option that adds the given key-value pair to the
// Error details. This is helpful for debugging errors.
func WithKeyVal(key string, val interface{}) Option {
return func(e *Error) error {
if e.Details == nil {
e.Details = make(map[string]interface{})
}
e.Details[key] = val
return e
}
}
// Error represents the CA API errors. // Error represents the CA API errors.
type Error struct { type Error struct {
Status int Status int
Err error Err error
Msg string Msg string
} Details map[string]interface{}
// New returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func New(status int, err error, opts ...Option) error {
var e *Error
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
for _, o := range opts {
o(e)
}
return e
} }
// ErrorResponse represents an error in JSON format. // ErrorResponse represents an error in JSON format.
@ -92,10 +94,11 @@ func (e *Error) Message() string {
// Wrap returns an error annotating err with a stack trace at the point Wrap is // Wrap returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil. // called, and the supplied message. If err is nil, Wrap returns nil.
func Wrap(status int, e error, m string, opts ...Option) error { func Wrap(status int, e error, m string, args ...interface{}) error {
if e == nil { if e == nil {
return nil return nil
} }
_, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok { if err, ok := e.(*Error); ok {
err.Err = errors.Wrap(err.Err, m) err.Err = errors.Wrap(err.Err, m)
e = err e = err
@ -111,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
if e == nil { if e == nil {
return nil return nil
} }
var opts []Option as, opts := splitOptionArgs(args)
for i, arg := range args {
// Once we find the first Option, assume that all further arguments are Options.
if _, ok := arg.(Option); ok {
for _, a := range args[i:] {
// Ignore any arguments after the first Option that are not Options.
if opt, ok := a.(Option); ok {
opts = append(opts, opt)
}
}
args = args[:i]
break
}
}
if err, ok := e.(*Error); ok { if err, ok := e.(*Error); ok {
err.Err = errors.Wrapf(err.Err, format, args...) err.Err = errors.Wrapf(err.Err, format, args...)
e = err e = err
} else { } else {
e = errors.Wrapf(e, format, args...) e = errors.Wrapf(e, format, as...)
} }
return StatusCodeError(status, e, opts...) return StatusCodeError(status, e, opts...)
} }
@ -174,77 +164,172 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error { func StatusCodeError(code int, e error, opts ...Option) error {
switch code { switch code {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequest(e, opts...) return BadRequestErr(e, opts...)
case http.StatusUnauthorized: case http.StatusUnauthorized:
return Unauthorized(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
return Forbidden(e, opts...) return ForbiddenErr(e, opts...)
case http.StatusInternalServerError: case http.StatusInternalServerError:
return InternalServerError(e, opts...) return InternalServerErr(e, opts...)
case http.StatusNotImplemented: case http.StatusNotImplemented:
return NotImplemented(e, opts...) return NotImplementedErr(e, opts...)
default: default:
return UnexpectedError(code, e, opts...) return UnexpectedErr(code, e, opts...)
} }
} }
var seeLogs = "Please see the certificate authority logs for more info." var (
seeLogs = "Please see the certificate authority logs for more info."
// BadRequestDefaultMsg 400 default msg
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
// UnauthorizedDefaultMsg 401 default msg
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
// ForbiddenDefaultMsg 403 default msg
ForbiddenDefaultMsg = "The request was forbidden by the certificate authority. " + seeLogs
// NotFoundDefaultMsg 404 default msg
NotFoundDefaultMsg = "The requested resource could not be found. " + seeLogs
// InternalServerErrorDefaultMsg 500 default msg
InternalServerErrorDefaultMsg = "The certificate authority encountered an Internal Server Error. " + seeLogs
// NotImplementedDefaultMsg 501 default msg
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
)
// InternalServerError returns a 500 error with the given error. // splitOptionArgs splits the variadic length args into string formatting args
func InternalServerError(err error, opts ...Option) error { // and Option(s) to apply to an Error.
if len(opts) == 0 { func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs)) indexOptionStart := -1
for i, a := range args {
if _, ok := a.(Option); ok {
indexOptionStart = i
break
} }
return New(http.StatusInternalServerError, err, opts...) }
if indexOptionStart < 0 {
return args, []Option{}
}
opts := []Option{}
// Ignore any non-Option args that come after the first Option.
for _, o := range args[indexOptionStart:] {
if opt, ok := o.(Option); ok {
opts = append(opts, opt)
}
}
return args[:indexOptionStart], opts
} }
// NotImplemented returns a 501 error with the given error. // NewErr returns a new Error. If the given error implements the StatusCoder
func NotImplemented(err error, opts ...Option) error { // interface we will ignore the given status.
if len(opts) == 0 { func NewErr(status int, err error, opts ...Option) error {
opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs)) var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
} }
return New(http.StatusNotImplemented, err, opts...) }
}
for _, o := range opts {
o(e)
}
return e
} }
// BadRequest returns an 400 error with the given error. // Errorf creates a new error using the given format and status code.
func BadRequest(err error, opts ...Option) error { func Errorf(code int, format string, args ...interface{}) error {
if len(opts) == 0 { as, opts := splitOptionArgs(args)
opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+ opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
"missing critical data. "+seeLogs)) e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
for _, o := range opts {
o(e)
} }
return New(http.StatusBadRequest, err, opts...) return e
} }
// Unauthorized returns an 401 error with the given error. // InternalServer creates a 500 error with the given format and arguments.
func Unauthorized(err error, opts ...Option) error { func InternalServer(format string, args ...interface{}) error {
if len(opts) == 0 { args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs)) return Errorf(http.StatusInternalServerError, format, args...)
}
return New(http.StatusUnauthorized, err, opts...)
} }
// Forbidden returns an 403 error with the given error. // InternalServerErr returns a 500 error with the given error.
func Forbidden(err error, opts ...Option) error { func InternalServerErr(err error, opts ...Option) error {
if len(opts) == 0 { opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs)) return NewErr(http.StatusInternalServerError, err, opts...)
}
return New(http.StatusForbidden, err, opts...)
} }
// NotFound returns an 404 error with the given error. // NotImplemented creates a 501 error with the given format and arguments.
func NotFound(err error, opts ...Option) error { func NotImplemented(format string, args ...interface{}) error {
if len(opts) == 0 { args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs)) return Errorf(http.StatusNotImplemented, format, args...)
}
return New(http.StatusNotFound, err, opts...)
} }
// UnexpectedError will be used when the certificate authority makes an outgoing // NotImplementedErr returns a 501 error with the given error.
func NotImplementedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
return NewErr(http.StatusNotImplemented, err, opts...)
}
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
return Errorf(http.StatusBadRequest, format, args...)
}
// BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, err, opts...)
}
// Unauthorized creates a 401 error with the given format and arguments.
func Unauthorized(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
return Errorf(http.StatusUnauthorized, format, args...)
}
// UnauthorizedErr returns an 401 error with the given error.
func UnauthorizedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
return NewErr(http.StatusUnauthorized, err, opts...)
}
// Forbidden creates a 403 error with the given format and arguments.
func Forbidden(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
return Errorf(http.StatusForbidden, format, args...)
}
// ForbiddenErr returns an 403 error with the given error.
func ForbiddenErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
return NewErr(http.StatusForbidden, err, opts...)
}
// NotFound creates a 404 error with the given format and arguments.
func NotFound(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
return Errorf(http.StatusNotFound, format, args...)
}
// NotFoundErr returns an 404 error with the given error.
func NotFoundErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
return NewErr(http.StatusNotFound, err, opts...)
}
// UnexpectedErr will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code. // request and receives an unhandled status code.
func UnexpectedError(code int, err error, opts ...Option) error { func UnexpectedErr(code int, err error, opts ...Option) error {
if len(opts) == 0 { opts = append(opts, withDefaultMessage("The certificate authority received an "+
opts = append(opts, WithMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code)) "unexpected HTTP status code - '%d'. "+seeLogs, code))
} return NewErr(code, err, opts...)
return New(code, err, opts...)
} }

View file

@ -1,11 +1,9 @@
package api package errs
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"github.com/smallstep/certificates/errs"
) )
func TestError_MarshalJSON(t *testing.T) { func TestError_MarshalJSON(t *testing.T) {
@ -24,7 +22,7 @@ func TestError_MarshalJSON(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
e := &errs.Error{ e := &Error{
Status: tt.fields.Status, Status: tt.fields.Status,
Err: tt.fields.Err, Err: tt.fields.Err,
} }
@ -47,15 +45,15 @@ func TestError_UnmarshalJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
expected *errs.Error expected *Error
wantErr bool wantErr bool
}{ }{
{"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &errs.Error{Status: 400, Err: fmt.Errorf("bad request")}, false}, {"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &Error{Status: 400, Err: fmt.Errorf("bad request")}, false},
{"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &errs.Error{}, true}, {"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &Error{}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
e := new(errs.Error) e := new(Error)
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
} }